package ru.yandex.msearch.util;

import java.io.IOException;

import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;

import ru.yandex.collection.BlockingBlockingQueue;

import ru.yandex.concurrent.LifoWaitBlockingQueue;
import ru.yandex.concurrent.NamedThreadFactory;
import ru.yandex.concurrent.TimeFrameQueue;

import ru.yandex.stater.Stater;
import ru.yandex.stater.StatsConsumer;

import ru.yandex.util.timesource.TimeSource;

public class SleepingIOScheduler extends IOScheduler {
    /** Priority divisor used by the single-argument constructor. */
    public static final int DEFAULT_PRIORITY_DIVISOR = 100;
    /**
     * Poll timeout of the scheduler loop and the interval that must be
     * exceeded before statistics are snapshotted, in milliseconds.
     */
    private static final int IOSTATS_SAMPLING_TIME = 100;
    /** Nanoseconds per microsecond. */
    private static final int NANOS_PER_MICRO = 1000;
    /** Microseconds per millisecond. */
    private static final int MICRO_PER_MILLIS = 1000;

    /** Core and maximum size of the IO thread pool created in start(). */
    private final int threadCount;
    /** Divisor applied to a task priority when ordering queued tasks. */
    private final int priorityDivisor;
    /** Handed to the scheduler thread; enables its wait-for-better-task loop. */
    private final boolean sleeping;
    /** Submitted tasks waiting to be picked up by the scheduler thread. */
    private final PriorityBlockingQueue<IOTask> queue;
    /** Dispatcher thread; assigned in start(). */
    private SleepingScheduler scheduler;
    /** Pool that runs the dispatched tasks; assigned in start(). */
    private ThreadPoolExecutor iopool;

    /**
     * Creates a scheduler that uses {@link #DEFAULT_PRIORITY_DIVISOR} and has
     * the sleeping mode switched on.
     *
     * @param threadCount size of the IO thread pool
     */
    public SleepingIOScheduler(final int threadCount) {
        this(
            threadCount,
            DEFAULT_PRIORITY_DIVISOR,
            true);
    }

    /**
     * Creates a scheduler. No threads are created here; the pool and the
     * scheduler thread are set up by {@code start()}.
     *
     * @param threadCount size of the IO thread pool
     * @param priorityDivisor divisor applied to a task priority when the
     *     queue position of the task is computed; must not be zero
     * @param sleeping whether the scheduler thread may hold a task back for
     *     a short time waiting for a better one
     * @throws IllegalArgumentException if {@code priorityDivisor} is zero
     */
    public SleepingIOScheduler(
        final int threadCount,
        final int priorityDivisor,
        final boolean sleeping)
    {
        // A zero divisor would otherwise surface much later as an
        // ArithmeticException on every single read/write operation.
        if (priorityDivisor == 0) {
            throw new IllegalArgumentException(
                "priorityDivisor must not be zero");
        }
        this.threadCount = threadCount;
        this.priorityDivisor = priorityDivisor;
        this.sleeping = sleeping;
        this.queue = new PriorityBlockingQueue<>();
    }

    /**
     * Builds the fixed-size IO thread pool and launches the daemon scheduler
     * thread that moves tasks from the priority queue into that pool.
     */
    @Override
    protected void start() {
        final int backlog = Math.max(threadCount, 2);
        final ThreadPoolExecutor pool = new ThreadPoolExecutor(
            threadCount,
            threadCount,
            1,
            TimeUnit.DAYS,
            new BlockingBlockingQueue<>(new LifoWaitBlockingQueue<>(backlog)),
            new NamedThreadFactory("IO") {
                @Override
                public Thread newThread(final Runnable r) {
                    // Every pool thread raises its IO priority before it
                    // starts executing the worker body.
                    final Runnable withIoPriority = () -> {
                        Compress.setThreadIOPriority(
                            Compress.IOPRIO_CLASS_BE,
                            Compress.IOPRIO_HIGH);
                        r.run();
                    };
                    return super.newThread(withIoPriority);
                }
            },
            // A rejected task is executed by the submitting thread.
            new ThreadPoolExecutor.CallerRunsPolicy());
        iopool = pool;

        final SleepingScheduler dispatcher =
            new SleepingScheduler(queue, pool, threadCount, sleeping);
        scheduler = dispatcher;
        dispatcher.setDaemon(true);
        dispatcher.start();
    }

    /**
     * Wraps {@code op} into a prioritized task, puts it on the scheduler
     * queue and blocks until the task has been executed.
     *
     * @param prio priority of the operation
     * @param op the operation to execute
     * @param read {@code true} for a read, {@code false} for a write; only
     *     used for statistics
     * @return the value produced by {@code op}
     * @throws IOException the {@link IOException} thrown by {@code op} as is;
     *     any other failure of {@code op}, or an interruption while waiting,
     *     wrapped into a new {@link IOException}
     */
    private <T> T ioOp(final int prio, final Callable<T> op, final boolean read)
        throws IOException
    {
        final IOTask<T> pending =
            new IOTask<>(op, prio, priorityDivisor, read);
        queue.offer(pending);
        try {
            return pending.get();
        } catch (ExecutionException e) {
            final Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            throw new IOException("IO error", e);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to the caller's own checks.
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted IO", e);
        }
    }

    /**
     * Schedules {@code op} as a read operation and waits for its result.
     *
     * @param prio priority of the operation
     * @param op the operation to execute
     * @return the value produced by {@code op}
     * @throws IOException if the operation fails or the wait is interrupted
     */
    @Override
    public <T> T readOp(final int prio, final Callable<T> op)
        throws IOException
    {
        return this.ioOp(prio, op, true);
    }

    /**
     * Schedules {@code op} as a write operation and waits for its result.
     *
     * @param prio priority of the operation
     * @param op the operation to execute
     * @return the value produced by {@code op}
     * @throws IOException if the operation fails or the wait is interrupted
     */
    @Override
    public <T> T writeOp(final int prio, final Callable<T> op)
        throws IOException
    {
        return this.ioOp(prio, op, false);
    }

    /**
     * Creates a fresh statistics window of the given length on the scheduler
     * thread and returns that thread as the {@link Stater}.
     *
     * @param metricsTimeFrame time frame passed to the statistics queue
     * @return the scheduler thread, which reports the collected statistics
     */
    @Override
    public Stater stater(final long metricsTimeFrame) {
        final SleepingScheduler dispatcher = scheduler;
        dispatcher.initStater(metricsTimeFrame);
        return dispatcher;
    }

    /**
     * Dispatcher thread. Takes tasks from the priority queue and hands them
     * to the IO pool, optionally holding a task back for a short time in
     * case a better one shows up. It also gathers per-priority IO statistics
     * and reports them through the {@link Stater} interface.
     */
    private static class SleepingScheduler extends Thread implements Stater {
        private final PriorityBlockingQueue<IOTask> queue;
        private final ThreadPoolExecutor iopool;
        // The map itself is accessed only from run().
        private final HashMap<Integer, IOStats> iostats;
        private final int threadCount;
        private final boolean sleeping;
        // Assigned by initStater() on a foreign thread and read by run() and
        // stats(); volatile so that the assignment is reliably visible.
        private volatile TimeFrameQueue<ImmutableIOStats> staterQueue = null;

        /**
         * @param queue queue the tasks are taken from
         * @param iopool pool the tasks are dispatched to
         * @param threadCount number of IO threads, used to scale the total
         *     time reported in statistics
         * @param sleeping whether a task may be held back while waiting for
         *     a better one
         */
        SleepingScheduler(
            final PriorityBlockingQueue<IOTask> queue,
            final ThreadPoolExecutor iopool,
            final int threadCount,
            final boolean sleeping)
        {
            super("IOScheduler");
            this.queue = queue;
            this.iopool = iopool;
            this.threadCount = threadCount;
            this.sleeping = sleeping;
            iostats = new HashMap<>();
        }

        /**
         * Replaces the statistics window with a new, empty one. Until this
         * is called for the first time no statistics snapshots are taken.
         *
         * @param metricsTimeFrame time frame passed to the statistics queue
         */
        public void initStater(final long metricsTimeFrame) {
            this.staterQueue = new TimeFrameQueue<>(metricsTimeFrame);
        }

        /**
         * Dispatch loop; never returns. Any throwable raised inside one
         * iteration is printed and the loop goes on.
         */
        @Override
        public void run() {
            int prevTaskPrio = 0;
            ArrayDeque<IOTask> deferred = new ArrayDeque<>();
            long iostatSampleStart = TimeSource.INSTANCE.currentTimeMillis();
            while (true) {
                try {
                    // The timeout doubles as the statistics sampling tick.
                    IOTask task =
                        queue.poll(
                            IOSTATS_SAMPLING_TIME,
                            TimeUnit.MILLISECONDS);
                    long currentTime = TimeSource.INSTANCE.currentTimeMillis();
                    long timeDiff = currentTime - iostatSampleStart;
                    // One volatile read, so the check and the use below see
                    // the same queue instance.
                    final TimeFrameQueue<ImmutableIOStats> statsSink =
                        staterQueue;
                    if (timeDiff > IOSTATS_SAMPLING_TIME
                        && statsSink != null)
                    {
                        iostatSampleStart = currentTime;
                        for (IOStats stat: iostats.values()) {
                            statsSink.accept(stat.snapshotAndReset());
                        }
                        // Time available to all IO threads in this sample,
                        // in microseconds.
                        statsSink.accept(
                            new ImmutableIOStats(
                                timeDiff * MICRO_PER_MILLIS * threadCount));
                    }
                    if (task == null) {
                        continue;
                    }
                    long start = System.nanoTime();
                    long elapsed = 0;
                    // The priority value is used as the maximum number of
                    // microseconds to wait for a better task.
                    long penalty = task.priority();
                    while (sleeping
                        && prevTaskPrio < task.priority() && elapsed < penalty)
                    {
                        // wait for a better task
                        if (task.priority() < penalty) {
                            // task is better, but not enough:
                            // decrease the penalty
                            penalty = task.priority();
                        }
                        deferred.add(task);
                        task = queue.poll(
                            Math.max(penalty - elapsed, 1),
                            TimeUnit.MICROSECONDS);
                        // Total time since the wait began; 'start' is never
                        // reset, so this is an assignment, not an
                        // accumulation.
                        elapsed =
                            (System.nanoTime() - start) / NANOS_PER_MICRO;
                        if (task == null) {
                            // Nothing arrived: resume the task deferred last.
                            task = deferred.removeLast();
                        }
                    }
                    if (!deferred.isEmpty()) {
                        // Put everything back so that the queue re-sorts the
                        // tasks, then execute whatever is on top.
                        queue.addAll(deferred);
                        queue.add(task);
                        deferred.clear();
                        task = queue.poll();
                    }
                    prevTaskPrio = task.priority();
                    IOStats iostat =
                        iostats.computeIfAbsent(task.priority(), IOStats::new);
                    task.iostat(iostat);
                    iopool.execute(task);
                } catch (Throwable t) {
                    t.printStackTrace();
                }
            }
        }

        /**
         * Sums up all snapshots currently held in the statistics window and
         * reports per-priority and total counters to {@code statsConsumer}.
         *
         * @param statsConsumer receiver of the computed statistics
         * @throws E if {@code statsConsumer} throws it
         */
        @Override
        public <E extends Exception> void stats(
            final StatsConsumer<? extends E> statsConsumer)
            throws E
        {
            long totalReads = 0L;
            long totalWrites = 0L;
            long totalReadTime = 0L;
            long totalWriteTime = 0L;
            long totalTime = 0L;

            final TimeFrameQueue<ImmutableIOStats> statsSource = staterQueue;
            HashMap<Integer, IOStats> prioStats = new HashMap<>();
            // NOTE(review): total-time snapshots carry prio -1 and therefore
            // also produce a "prio--1" group of zero-valued metrics below;
            // kept as is so that the set of reported metrics is unchanged.
            for (ImmutableIOStats stat: statsSource) {
                prioStats
                    .computeIfAbsent(stat.prio(), IOStats::new)
                    .add(stat);
                totalTime += stat.totalTime();
            }
            for (IOStats stat: prioStats.values()) {
                // These getters reset the adders; each is read exactly once.
                long reads = stat.reads();
                long writes = stat.writes();
                long readTime = stat.readTime();
                long writeTime = stat.writeTime();

                totalReads += reads;
                totalWrites += writes;
                totalReadTime += readTime;
                totalWriteTime += writeTime;
                int prio = stat.prio();
                statsConsumer.stat(
                    "iosched-reads-prio-" + prio + "_ammm",
                    reads);
                statsConsumer.stat(
                    "iosched-writes-prio-" + prio + "_ammm",
                    writes);
                statsConsumer.stat(
                    "iosched-read-time-prio-" + prio + "_ammm",
                    readTime);
                statsConsumer.stat(
                    "iosched-write-time-prio-" + prio + "_ammm",
                    writeTime);
            }
            long totalIops = totalReads + totalWrites;
            statsConsumer.stat("iosched-iops-total_ammm", totalIops);
            statsConsumer.stat("iosched-reads-total_ammm", totalReads);
            statsConsumer.stat("iosched-writes-total_ammm", totalWrites);
            statsConsumer.stat("iosched-read-time-total_ammm", totalReadTime);
            statsConsumer.stat(
                "iosched-write-time-total_ammm",
                totalWriteTime);
            long totalIopsTime = totalReadTime + totalWriteTime;
            // Never report less scheduler time than was spent in IO, so that
            // the idle time cannot go negative.
            totalTime = Math.max(totalIopsTime, totalTime);
            long idle = totalTime - totalIopsTime;
            statsConsumer.stat("iosched-idle-time-total_ammm", idle);
            statsConsumer.stat("iosched-iops-time-total_ammm", totalIopsTime);
            statsConsumer.stat("iosched-sched-time-total_ammm", totalTime);
        }
    }

    /**
     * A queued IO operation. Tasks are ordered by their creation time in
     * milliseconds plus {@code priority / priorityDivisor}; the smaller value
     * comes first. Running the task also records its duration in the
     * attached statistics, if any.
     */
    private static class IOTask<T> extends FutureTask<T>
        implements Comparable<IOTask<T>>
    {
        private final int priority;
        private final long correctedPriority;
        private final boolean read;
        private IOStats iostat;

        /**
         * @param callable the operation to execute
         * @param priority priority of the operation
         * @param priorityDivisor divisor applied to {@code priority} when
         *     the ordering key is computed
         * @param read {@code true} for a read, {@code false} for a write
         */
        IOTask(
            final Callable<T> callable,
            final int priority,
            final int priorityDivisor,
            final boolean read)
        {
            super(callable);
            this.priority = priority;
            this.read = read;
            final long createdAt = TimeSource.INSTANCE.currentTimeMillis();
            this.correctedPriority = createdAt + (priority / priorityDivisor);
        }

        /** Attaches the statistics object that run() reports to. */
        public void iostat(final IOStats iostat) {
            this.iostat = iostat;
        }

        /** Returns the priority the task was created with. */
        public int priority() {
            return priority;
        }

        /**
         * Executes the operation and, when statistics are attached, records
         * one operation with the measured duration in microseconds.
         */
        @Override
        public void run() {
            final long startedAt = System.nanoTime();
            super.run();
            final long elapsedNanos = System.nanoTime() - startedAt;
            final IOStats sink = iostat;
            if (sink == null) {
                return;
            }
            sink.op(read, elapsedNanos / NANOS_PER_MICRO);
        }

        /** Orders tasks by ascending corrected priority. */
        @Override
        public int compareTo(final IOTask<T> other) {
            final long mine = this.correctedPriority;
            final long theirs = other.correctedPriority;
            return Long.compare(mine, theirs);
        }
    }

    /**
     * Mutable read/write counters for one priority. Note that the value
     * getters are destructive: each of them returns the current sum and
     * resets the underlying adder to zero.
     */
    private static final class IOStats {
        private final LongAdder reads = new LongAdder();
        private final LongAdder writes = new LongAdder();
        private final LongAdder readTime = new LongAdder();
        private final LongAdder writeTime = new LongAdder();
        private final int prio;

        /** @param prio priority these counters belong to */
        public IOStats(final int prio) {
            this.prio = prio;
        }

        /**
         * Records a single operation.
         *
         * @param read {@code true} for a read, {@code false} for a write
         * @param time duration of the operation
         */
        public void op(final boolean read, final long time) {
            final LongAdder counter;
            final LongAdder timer;
            if (read) {
                counter = reads;
                timer = readTime;
            } else {
                counter = writes;
                timer = writeTime;
            }
            counter.increment();
            timer.add(time);
        }

        /** Adds the counters of a snapshot to this object. */
        public void add(final ImmutableIOStats other) {
            reads.add(other.reads());
            readTime.add(other.readTime());
            writes.add(other.writes());
            writeTime.add(other.writeTime());
        }

        /** Returns the priority these counters belong to. */
        public int prio() {
            return prio;
        }

        /** Returns the number of reads and resets it to zero. */
        public long reads() {
            return reads.sumThenReset();
        }

        /** Returns the number of writes and resets it to zero. */
        public long writes() {
            return writes.sumThenReset();
        }

        /** Returns the accumulated read time and resets it to zero. */
        public long readTime() {
            return readTime.sumThenReset();
        }

        /** Returns the accumulated write time and resets it to zero. */
        public long writeTime() {
            return writeTime.sumThenReset();
        }

        /** Copies the current values into a snapshot and resets them. */
        public ImmutableIOStats snapshotAndReset() {
            return new ImmutableIOStats(this);
        }
    }

    /**
     * Immutable statistics record. It is either a snapshot of the counters
     * of one priority (total time is zero) or a total-time marker (all
     * counters are zero and the priority is -1).
     */
    private static final class ImmutableIOStats {
        private final long reads;
        private final long writes;
        private final long readTime;
        private final long writeTime;
        private final long totalTime;
        private final int prio;

        /**
         * Takes a snapshot of {@code stat}. Reading the values resets the
         * counters of {@code stat}.
         */
        public ImmutableIOStats(final IOStats stat) {
            this(
                stat.reads(),
                stat.writes(),
                stat.readTime(),
                stat.writeTime(),
                0,
                stat.prio());
        }

        /** Creates a record that carries only a total time. */
        public ImmutableIOStats(final long totalTime) {
            this(0, 0, 0, 0, totalTime, -1);
        }

        // Canonical constructor shared by the two public ones.
        private ImmutableIOStats(
            final long reads,
            final long writes,
            final long readTime,
            final long writeTime,
            final long totalTime,
            final int prio)
        {
            this.reads = reads;
            this.writes = writes;
            this.readTime = readTime;
            this.writeTime = writeTime;
            this.totalTime = totalTime;
            this.prio = prio;
        }

        /** Returns the priority, or -1 for a total-time record. */
        public int prio() {
            return prio;
        }

        /** Returns the number of reads. */
        public long reads() {
            return reads;
        }

        /** Returns the number of writes. */
        public long writes() {
            return writes;
        }

        /** Returns the accumulated read time. */
        public long readTime() {
            return readTime;
        }

        /** Returns the accumulated write time. */
        public long writeTime() {
            return writeTime;
        }

        /** Returns the total time, zero for a per-priority snapshot. */
        public long totalTime() {
            return totalTime;
        }
    }
}
