package ru.yandex.market.graphouse.cacher;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.market.graphouse.server.MetricBatch;
import ru.yandex.market.graphouse.stockpile.GraphouseStockpileClient;
import ru.yandex.market.graphouse.stockpile.StockpilePushBatch;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.util.concurrent.ThreadUtils;
import ru.yandex.stockpile.client.writeRequest.StockpileShardWriteRequest;


/**
 * Buffers incoming {@link MetricBatch}es and periodically hands them to Stockpile writers.
 *
 * <p>Producers call {@link #submitMetrics(MetricBatch)}, which blocks until there is capacity
 * for {@code batch.size()} more metrics. Capacity is returned only after a writer has saved
 * the corresponding data. A background thread started by the constructor wakes up every
 * {@code FLUSH_INTERVAL_SECONDS} seconds and drains the queue. It converts the drained
 * batches via {@link GraphouseStockpileClient#compress} and submits each resulting
 * {@link StockpilePushBatch} to a pool of {@code WRITERS_COUNT} writer threads. Those threads
 * retry failed saves until they succeed. Call {@link #shutdownHook()} to flush everything
 * that is buffered and stop.
 *
 * <p>NOTE(review): capacity accounting assumes {@code StockpilePushBatch.size()} counts the same
 * metrics that {@code MetricBatch.size()} counted on submission. Otherwise permits drift.
 * TODO confirm against {@code compress}.
 *
 * @author Maksim Leonov (nohttp@)
 */
public class MetricCacher implements Runnable {

    private static final Logger log = LoggerFactory.getLogger(MetricCacher.class);

    /** Number of threads that write batches to Stockpile concurrently. */
    private static final int WRITERS_COUNT = 16;
    /** Maximum number of metrics that may be queued or still being written at once. */
    private static final int CACHE_SIZE = 4_000_000;
    /** A flush keeps draining the queue while each drain yields at least this many metrics. */
    private static final int BATCH_SIZE = 1_000_000;
    /** Pause between background flushes. */
    private static final int FLUSH_INTERVAL_SECONDS = 1;
    /** Pause before retrying a failed save (the failure log message says "1 second"). */
    private static final long RETRY_DELAY_MILLIS = 1000;

    public final String id;

    private final GraphouseStockpileClient client;

    /**
     * Free capacity measured in metrics, not batches. Producers acquire permits in
     * submitMetrics. Writers release them only after a successful save.
     */
    private final Semaphore semaphore = new Semaphore(CACHE_SIZE, false);
    private final LinkedBlockingQueue<MetricBatch> metricQueue = new LinkedBlockingQueue<>(CACHE_SIZE);

    private final ExecutorService executorService;

    /** Background flush thread running {@link #run()}. */
    private final Thread cacherThread;
    /** Cleared by shutdownHook() to stop the background flush loop. */
    private volatile boolean running = true;

    public MetricCacher(String id, GraphouseStockpileClient client) {
        this.id = id;
        this.client = client;
        this.executorService = Executors.newFixedThreadPool(WRITERS_COUNT, ThreadUtils.newThreadFactory(id));

        MetricRegistry metricRegistry = MetricRegistry.root();
        metricRegistry.lazyGaugeInt64(
            "MetricCacher.queueSize", Labels.of("id", id), this::getQueueSize);
        metricRegistry.lazyGaugeInt64(
            "MetricCacher.maxTimeInQueueMillis", Labels.of("id", id), this::getMaxTimeInQueueMillis);

        // Started last so every field is assigned before the thread can observe `this`.
        this.cacherThread = new Thread(this, "Metric cacher thread");
        this.cacherThread.start();
    }

    /**
     * Stops the background flusher, writes out everything still queued, and waits for all
     * writers to finish.
     *
     * <p>The background thread is stopped first so it cannot drain the queue concurrently. If
     * it kept running, it could submit work after the executor is shut down, and those
     * metrics would be lost. Interrupts received while waiting are remembered and restored on
     * return.
     */
    public void shutdownHook() {
        log.info("[{}] Shutting down metric cacher. Saving all cached metrics...", id);
        boolean interrupted = stopCacherThread();
        while (!metricQueue.isEmpty()) {
            log.info("[{}] {} metric batches remaining", id, metricQueue.size());
            writeMetrics();
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        executorService.shutdown();
        log.info("[{}] Awaiting save completion", id);
        while (!executorService.isTerminated()) {
            try {
                executorService.awaitTermination(100, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        log.info("[{}] Metric cacher stopped", id);
    }

    /**
     * Asks the background flush loop to stop and waits for it to exit. This takes at most
     * about one flush interval plus one flush.
     *
     * @return whether the calling thread was interrupted while waiting
     */
    private boolean stopCacherThread() {
        running = false;
        if (Thread.currentThread() == cacherThread) {
            // Joining ourselves would never return.
            return false;
        }
        boolean interrupted = false;
        while (cacherThread.isAlive()) {
            try {
                cacherThread.join();
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        return interrupted;
    }

    /**
     * Enqueues a batch, blocking until there is capacity for {@code batch.size()} metrics.
     *
     * @param batch metrics to buffer
     * @throws IllegalArgumentException if the batch alone exceeds the total capacity; it could
     *         never be admitted, so the call would otherwise block forever
     * @throws RuntimeException wrapping the {@link InterruptedException} if interrupted while
     *         waiting; the thread's interrupt status is restored
     */
    public void submitMetrics(MetricBatch batch) {
        int size = batch.size();
        if (size > CACHE_SIZE) {
            throw new IllegalArgumentException(
                "[" + id + "] batch of " + size + " metrics exceeds cache size " + CACHE_SIZE);
        }
        try {
            semaphore.acquire(size);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        try {
            metricQueue.put(batch);
        } catch (InterruptedException e) {
            // The batch was not enqueued, so give back the capacity reserved for it.
            semaphore.release(size);
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Background flush loop: every FLUSH_INTERVAL_SECONDS, drain the queue to the writers.
     *
     * <p>The loop ends when shutdownHook() clears {@code running}, or when this thread is
     * interrupted. Anything left in the queue is flushed by shutdownHook().
     */
    @Override
    public void run() {
        while (running) {
            try {
                Thread.sleep(TimeUnit.SECONDS.toMillis(FLUSH_INTERVAL_SECONDS));
            } catch (InterruptedException e) {
                // Treat an interrupt as a stop request. Swallowing it would make the loop unstoppable.
                Thread.currentThread().interrupt();
                return;
            }
            try {
                writeMetrics();
            } catch (Throwable t) {
                log.warn("couldn't write metrics", t);
            }
        }
    }

    /**
     * Drains the queue and submits the resulting batches to the writer pool.
     *
     * <p>It repeats while a drain produces at least BATCH_SIZE metrics, because more data has
     * probably arrived in the meantime.
     */
    private void writeMetrics() {
        if (metricQueue.isEmpty()) {
            return;
        }
        int queueSize = getQueueSize();
        double queueOccupancyPercent = queueSize * 100.0 / CACHE_SIZE;
        log.info("[{}] Metric queue size: {}({}%)", id, queueSize, queueOccupancyPercent);
        while (true) {
            long submitted = 0;
            for (StockpilePushBatch batch : createBatches()) {
                if (batch.size() == 0) {
                    // Skip, but do not stop: the remaining batches were already drained from
                    // the queue and would otherwise be lost along with their permits.
                    continue;
                }
                executorService.execute(new StockpileWriterWorker(batch));
                submitted += batch.size();
            }
            if (submitted < BATCH_SIZE) {
                // Less than a full batch was drained, so there is no need to drain again now.
                log.info("[{}] submitted {} metrics (less than a full batch), stop draining", id, submitted);
                return;
            }
        }
    }

    /** Moves everything currently queued into a buffer and converts it via {@code client.compress}. */
    private List<StockpilePushBatch> createBatches() {
        List<MetricBatch> buffer = new ArrayList<>();
        metricQueue.drainTo(buffer);
        return client.compress(buffer);
    }

    /** Age of the oldest queued batch in milliseconds, or 0 when the queue is empty. */
    private long getMaxTimeInQueueMillis() {
        MetricBatch oldestBatch = metricQueue.peek();
        if (oldestBatch == null) {
            return 0;
        }
        return System.currentTimeMillis() - oldestBatch.getSubmittedAtMillis();
    }

    /** Number of metrics that are queued or being written, derived from the used semaphore permits. */
    private int getQueueSize() {
        return CACHE_SIZE - semaphore.availablePermits();
    }

    /**
     * Saves one batch to Stockpile, retrying every RETRY_DELAY_MILLIS until it succeeds. It
     * then closes the batch and returns its capacity to the semaphore.
     */
    private class StockpileWriterWorker implements Runnable {

        private final StockpilePushBatch metrics;

        StockpileWriterWorker(StockpilePushBatch metrics) {
            this.metrics = metrics;
        }

        @Override
        public void run() {
            // Read once up front: the same count is logged and released after close().
            int size = metrics.size();
            boolean interrupted = false;
            while (!trySave(size)) {
                try {
                    Thread.sleep(RETRY_DELAY_MILLIS);
                } catch (InterruptedException e) {
                    // Keep retrying, since the data must not be dropped. Restore the flag at the end.
                    interrupted = true;
                }
            }
            // Closed outside the retry loop, so a failing close() cannot trigger a second save.
            closeBatch();
            semaphore.release(size);
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        /** Makes one save attempt; returns true on success, logs and returns false on failure. */
        private boolean trySave(int size) {
            try {
                long start = System.currentTimeMillis();
                client.saveMetrics(metrics).join();
                long processed = System.currentTimeMillis() - start;
                log.info("Saved {} metrics in {}ms", size, processed);
                return true;
            } catch (Exception e) {
                long memorySize = metrics.getWritesByShard().values().stream()
                    .mapToLong(StockpileShardWriteRequest::memorySizeIncludingSelf)
                    .sum();
                long recordCount = metrics.getWritesByShard().values().stream()
                    .mapToLong(StockpileShardWriteRequest::countRecords)
                    .sum();
                log.error("Failed to save {} metrics, memorySize={} recordCount={} . Waiting 1 second before retry",
                    size, memorySize, recordCount, e);
                return false;
            }
        }

        /** Closes the already-saved batch; failures are logged, not retried. */
        private void closeBatch() {
            try {
                metrics.close();
            } catch (Exception e) {
                log.warn("Failed to close saved batch of {} metrics", metrics.size(), e);
            }
        }
    }
}
