package ru.yandex.market.graphouse.server;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.ContextClosedEvent;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;

import ru.yandex.market.graphouse.cacher.MetricCacher;
import ru.yandex.market.graphouse.cacher.MetricCacherProvider;
import ru.yandex.market.graphouse.search.MetricAddException;
import ru.yandex.market.graphouse.search.MetricResponseStatus;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.GaugeInt64;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.config.protobuf.graphite.storage.TMetricServer;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;

/**
 * Line-oriented TCP metric receiver.
 *
 * <p>Listens on a fixed port. A pool of {@link MetricServerWorker}s competes for {@code accept()}
 * on the shared server socket. Each accepted connection is read line by line, and every line is
 * turned into a metric by {@link MetricFactory}. Metrics the factory reports as resolved are
 * batched into a {@link MetricCacher}; unresolved ones are handed to {@link MetricLimbo}.
 *
 * @author Dmitry Andreev <a href="mailto:AndreevDm@yandex-team.ru"/>
 * @date 02/04/15
 */
@Component
public class MetricServer {

    private static final Logger log = LoggerFactory.getLogger(MetricServer.class);

    private static final int CACHER_COUNT = 4;
    private static final int PORT = 2024;
    /** Used both as the {@code accept()} timeout and as the per-connection read timeout. */
    private static final int SOCKET_TIMEOUT_MILLIS = 15_000;
    private static final int THREAD_COUNT = 300;
    private static final int PUSH_BATCH_SIZE = 20_000;

    private final MetricCacherProvider metricCacherProvider;
    private final MetricFactory metricFactory;
    private final MetricLimbo metricLimbo;
    private final ExecutorService executorService;
    /** Once a connection sends more lines than this, reading stops and the rest is discarded. */
    private final int maxMetricsForReject;
    /** Connections that send more lines than this are reported with a warning. */
    private final int limitPushMetrics;
    /** A final unresolved batch larger than this is reported with a warning. */
    private final int limitForNewMetrics;

    private volatile boolean isRunning = true;

    private final ServerSocket serverSocket;
    private final ServerMetrics serverMetrics;
    private final AtomicInteger busyWorkers = new AtomicInteger();
    private final AtomicLong connectionsAccepted = new AtomicLong();

    /**
     * Opens the server socket, creates the cachers and starts all worker tasks.
     *
     * @throws Exception if the server socket cannot be opened or configured
     */
    @Autowired
    public MetricServer(
        MetricCacherProvider metricCacherProvider,
        MetricFactory metricFactory,
        MetricLimbo metricLimbo,
        TMetricServer config,
        ThreadPoolProvider threadPoolProvider) throws Exception
    {
        this.metricCacherProvider = metricCacherProvider;
        this.metricFactory = metricFactory;
        this.metricLimbo = metricLimbo;
        this.executorService = threadPoolProvider.getExecutorService(config.getThreadPoolName(), "MetricServer.ThreadPoolName");
        this.limitPushMetrics = config.getLimitPushMetrics();
        this.limitForNewMetrics = config.getLimitForNewMetrics();
        this.maxMetricsForReject = config.getMaxMetricsForReject();

        log.info("Starting metric server on port: {}", PORT);
        this.serverSocket = new ServerSocket(PORT);
        // accept() times out periodically so that listeners re-check isRunning
        this.serverSocket.setSoTimeout(SOCKET_TIMEOUT_MILLIS);

        log.info("Starting {} metric server threads", THREAD_COUNT);
        MetricCacher[] metricCachers = makeCachers();

        // Must be fully initialised BEFORE any worker is submitted: workers dereference
        // serverMetrics as soon as they accept a connection.
        this.serverMetrics = new ServerMetrics(MetricRegistry.root());
        this.serverMetrics.setTotalMetricWorkers(THREAD_COUNT);

        for (int i = 0; i < THREAD_COUNT; i++) {
            executorService.submit(new MetricServerWorker(metricCachers[i % CACHER_COUNT]));
        }
        log.info("Metric server started");
    }

    /**
     * Stops the workers, closes the server socket and shuts down the cacher provider.
     *
     * <p>If the calling thread is interrupted while waiting for the workers, the shutdown sequence
     * is still completed and the interrupt flag is restored at the end.
     */
    @EventListener(ContextClosedEvent.class)
    public void shutdownHandler() {
        log.info("Shutting down metric server");
        // Make workers leave their accept loop and keep new connections from being processed.
        isRunning = false;
        executorService.shutdownNow();

        boolean interrupted = false;
        try {
            // worst case scenario:
            // 1 timeout for a thread that just started accept()ing to actually accept,
            // 1000 ms for it to receive some data,
            // 1 more timeout for it to wait for a client that is not going to close the connection
            long totalTimeoutMillis = 2L * SOCKET_TIMEOUT_MILLIS + 1000;
            if (!executorService.awaitTermination(totalTimeoutMillis, TimeUnit.MILLISECONDS)) {
                log.warn("Metric server workers did not stop within {} ms", totalTimeoutMillis);
            }
        } catch (InterruptedException e) {
            // Finish shutting down first; the interrupt flag is restored below.
            interrupted = true;
        }
        // TODO: check what happens on balancer when we have a socket, but do not accept() it
        IOUtils.closeQuietly(serverSocket);
        log.info("Metric server does not accept metrics anymore");
        metricCacherProvider.shutdown();
        log.info("Metric server stopped");
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /** Creates {@link #CACHER_COUNT} cachers named {@code MetricServer-<index>}. */
    private MetricCacher[] makeCachers() {
        MetricCacher[] metricCachers = new MetricCacher[CACHER_COUNT];
        for (int i = 0; i < CACHER_COUNT; i++) {
            metricCachers[i] = metricCacherProvider.makeCacher("MetricServer-" + i);
        }
        return metricCachers;
    }

    /**
     * Pool task that waits for the metric factory to become ready, then repeatedly accepts a
     * connection on the shared server socket and drains it into its assigned cacher.
     */
    private final class MetricServerWorker implements Runnable {

        private final MetricCacher metricCacher;

        private MetricServerWorker(MetricCacher metricCacher) {
            this.metricCacher = metricCacher;
        }

        @Override
        public void run() {
            if (!awaitFactoryReady()) {
                return;
            }
            while (isRunning && !serverSocket.isClosed()) {
                try {
                    read();
                } catch (SocketTimeoutException ste) {
                    // thrown from accept(): this is by design - we have to check isRunning periodically
                } catch (Exception e) {
                    log.warn("Failed to read data", e);
                }
            }
            log.info("MetricServerWorker stopped");
        }

        /**
         * Polls every 300 ms until the metric factory is ready or the server stops running.
         *
         * @return {@code false} if interrupted while waiting (flag restored), {@code true} otherwise
         */
        private boolean awaitFactoryReady() {
            while (isRunning && !metricFactory.isReady()) {
                try {
                    Thread.sleep(300);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
            return true;
        }

        /**
         * Accepts one connection and processes it. The socket is always closed and the
         * busy-worker counter is always restored, even if processing or bookkeeping fails.
         *
         * @throws SocketTimeoutException if {@code accept()} timed out
         * @throws IOException on other socket errors
         */
        private void read() throws IOException {
            try (Socket socket = serverSocket.accept()) {
                busyWorkers.incrementAndGet();
                try {
                    handleConnection(socket);
                } finally {
                    busyWorkers.decrementAndGet();
                }
            }
        }

        /**
         * Reads metric lines from {@code socket} until EOF, a read timeout, or the per-connection
         * limit. Submits them in batches and records timing statistics.
         *
         * <p>Read timeouts and "Connection reset" are logged and swallowed. Other socket errors
         * are rethrown.
         */
        private void handleConnection(Socket socket) throws IOException {
            long connectionId = connectionsAccepted.incrementAndGet();
            log.info("Accepted connection {}", connectionId);
            int metricsInConnection = 0;
            long startTimeMillis = System.currentTimeMillis();

            // Wall-clock time split between waiting for input, resolving lines and submitting batches.
            long readTimeMillis = 0;
            long searchTimeMillis = 0;
            long insertTimeMillis = 0;

            long metricEndProcessTimeMillis = startTimeMillis;

            String ipAddress = String.valueOf(socket.getRemoteSocketAddress());
            try {
                socket.setSoTimeout(SOCKET_TIMEOUT_MILLIS);
                socket.setKeepAlive(false);
                // NOTE(review): uses the platform default charset; consider StandardCharsets.UTF_8.
                BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()));

                // Most connections do not have enough metrics to fill one batch
                MetricBatch resolvedMetrics = new MetricBatch(100);
                List<UnresolvedMetric> unresolvedMetrics = UnresolvedMetric.f.newInstance();

                String line;
                while ((line = reader.readLine()) != null) {
                    metricsInConnection++;
                    if (metricsInConnection > maxMetricsForReject) {
                        // Stop reading: the remaining lines of this connection are discarded.
                        break;
                    }
                    long metricStartProcessTimeMillis = System.currentTimeMillis();
                    readTimeMillis += metricStartProcessTimeMillis - metricEndProcessTimeMillis;

                    long metricEndSearchTimeMillis;
                    try {
                        AnyMetric metric = metricFactory.createMetric(line, ipAddress);
                        metricEndSearchTimeMillis = System.currentTimeMillis();

                        if (metric.isResolved()) {
                            resolvedMetrics.add((ResolvedMetric) metric);
                            if (resolvedMetrics.size() == PUSH_BATCH_SIZE) {
                                metricCacher.submitMetrics(resolvedMetrics);
                                resolvedMetrics = new MetricBatch(PUSH_BATCH_SIZE);
                            }
                        } else {
                            unresolvedMetrics.add((UnresolvedMetric) metric);
                            if (unresolvedMetrics.size() == PUSH_BATCH_SIZE) {
                                metricLimbo.submitMetrics(unresolvedMetrics);
                                // The list has been handed off; log the size known before submission.
                                log.warn("ip:{} pushed {} new metrics", ipAddress, PUSH_BATCH_SIZE);
                                unresolvedMetrics = UnresolvedMetric.f.newInstance();
                            }
                        }
                        serverMetrics.recordStatus(MetricResponseStatus.IN_FLIGHT_QUEUE);
                    } catch (MetricAddException e) {
                        metricEndSearchTimeMillis = System.currentTimeMillis();
                        log.info("{}: [{}]", e.status, line);
                        serverMetrics.recordStatus(e.status);
                    }

                    searchTimeMillis += metricEndSearchTimeMillis - metricStartProcessTimeMillis;
                    metricEndProcessTimeMillis = System.currentTimeMillis();
                    insertTimeMillis += metricEndProcessTimeMillis - metricEndSearchTimeMillis;
                }
                if (!resolvedMetrics.isEmpty()) {
                    metricCacher.submitMetrics(resolvedMetrics);
                }
                if (!unresolvedMetrics.isEmpty()) {
                    int newMetricsCount = unresolvedMetrics.size();
                    metricLimbo.submitMetrics(unresolvedMetrics);
                    if (newMetricsCount > limitForNewMetrics) {
                        log.warn("ip:{} pushed {} new metrics", ipAddress, newMetricsCount);
                    }
                }
                insertTimeMillis += System.currentTimeMillis() - metricEndProcessTimeMillis;
            } catch (SocketTimeoutException e) {
                log.warn("Socket timeout from {}", ipAddress);
                serverMetrics.recordTimeoutMillis(System.currentTimeMillis() - metricEndProcessTimeMillis);
            } catch (SocketException e) {
                // getMessage() may be null, so compare from the literal side.
                if ("Connection reset".equals(e.getMessage())) {
                    // Most likely this is our balancer <3
                    log.warn("Connection reset from {}", ipAddress);
                } else {
                    throw e;
                }
            } finally {
                long endTimeMillis = System.currentTimeMillis();
                serverMetrics.recordConnectionDurationHistogram(endTimeMillis - startTimeMillis);
                log.info("ip:{} released connection:{} cached:{} metrics in {} ms " +
                        "({} ms reads; {} ms search; {} ms insert) Using metric cacher:{}",
                    ipAddress,
                    connectionId,
                    metricsInConnection,
                    endTimeMillis - startTimeMillis,
                    readTimeMillis,
                    searchTimeMillis,
                    insertTimeMillis,
                    metricCacher.id);
                if (metricsInConnection > limitPushMetrics) {
                    log.warn("ip:{} pushed too many metrics:{}", ipAddress, metricsInConnection);
                }
                serverMetrics.recordReadMillis(readTimeMillis);
                serverMetrics.recordSearchMillis(searchTimeMillis);
                serverMetrics.recordInsertMillis(insertTimeMillis);
            }
        }
    }

    /** Monitoring metrics of the server, registered in the registry passed to the constructor. */
    private final class ServerMetrics {

        private final GaugeInt64 totalMetricWorkers;
        private final Histogram connectionDurationHistogram;
        private final Histogram readMillis;
        private final Histogram searchMillis;
        private final Histogram insertMillis;
        private final Histogram timeoutMillis;
        private final Map<MetricResponseStatus, Rate> writeStatus;

        ServerMetrics(MetricRegistry metricRegistry) {
            metricRegistry.lazyGaugeInt64("server.busyMetricWorkers", () -> busyWorkers.get());
            this.totalMetricWorkers = metricRegistry.gaugeInt64("server.totalMetricWorkers");
            metricRegistry.lazyGaugeInt64("server.connectionsAccepted", () -> connectionsAccepted.get());
            this.connectionDurationHistogram = timingHistogram(metricRegistry, "server.connectionDurationHistogram");
            this.readMillis = timingHistogram(metricRegistry, "server.readMillis");
            this.searchMillis = timingHistogram(metricRegistry, "server.searchMillis");
            this.insertMillis = timingHistogram(metricRegistry, "server.insertMetabaseMillis");
            this.timeoutMillis = timingHistogram(metricRegistry, "server.timeoutMillis");
            this.writeStatus = new EnumMap<>(MetricResponseStatus.class);
            for (MetricResponseStatus status : MetricResponseStatus.values()) {
                writeStatus.put(status, metricRegistry.rate("writeStatus", Labels.of("status", status.name())));
            }
        }

        /** Registers a histogram rate with the bucket layout shared by all server timings. */
        private Histogram timingHistogram(MetricRegistry metricRegistry, String name) {
            return metricRegistry.histogramRate(name, Histograms.exponential(11, 2, 64));
        }

        void setTotalMetricWorkers(long totalMetricWorkers) {
            this.totalMetricWorkers.set(totalMetricWorkers);
        }

        void recordConnectionDurationHistogram(long time) {
            this.connectionDurationHistogram.record(time);
        }

        void recordReadMillis(long time) {
            this.readMillis.record(time);
        }

        void recordSearchMillis(long time) {
            this.searchMillis.record(time);
        }

        void recordInsertMillis(long time) {
            this.insertMillis.record(time);
        }

        void recordTimeoutMillis(long time) {
            this.timeoutMillis.record(time);
        }

        void recordStatus(MetricResponseStatus status) {
            writeStatus.get(status).inc();
        }
    }
}
