package ru.yandex.direct.ess.router.components;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.ess.router.config.LogbrokerWriterAdditionalConfig;
import ru.yandex.direct.ess.router.models.rule.RuleProcessingResult;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Counter;
import ru.yandex.monlib.metrics.primitives.GaugeInt64;
import ru.yandex.monlib.metrics.registry.MetricRegistry;

/**
 * Base class for router monitoring.
 * <p>
 * It maintains the following metrics:
 * <ul>
 *   <li>per-shard delay gauges;</li>
 *   <li>per-shard counters of rows and of processed / additional / skipped-additional logic objects;</li>
 *   <li>counters of written messages per writer config;</li>
 *   <li>a gauge holding the duration of the last write.</li>
 * </ul>
 * Concrete subclasses only supply the metric names.
 * <p>
 * Partition numbers are turned into {@value #SHARD_LABEL} label values as {@code partition + 1}.
 * Per-shard metric objects are cached in {@link ConcurrentHashMap}s. Each one is created once via
 * {@code computeIfAbsent}.
 */
public abstract class RouterMonitoring {

    private final Logger logger = LoggerFactory.getLogger(getClass());

    // Labels
    private static final String SHARD_LABEL = "shard";
    private static final String TOPIC_LABEL = "topic_label";

    // Metrics
    /**
     * Last timestamp passed to {@link #setTimestamp(int, long)}, keyed by shard label value.
     * Read by the lazily evaluated delay gauges.
     */
    protected final Map<String, Long> maxTimestamps = new ConcurrentHashMap<>();
    private final Map<String, Counter> rowsCount = new ConcurrentHashMap<>();
    private final Map<String, Counter> processedLogicObjects = new ConcurrentHashMap<>();
    private final Map<String, Counter> additionalLogicObjects = new ConcurrentHashMap<>();
    private final Map<String, Counter> skippedAdditionalLogicObjects = new ConcurrentHashMap<>();
    private final Map<LogbrokerWriterAdditionalConfig, Counter> writingMessages = new ConcurrentHashMap<>();
    private final GaugeInt64 writingDurationMetric;

    private final MetricRegistry metricRegistry;

    /**
     * Creates the monitoring and registers the writing-duration gauge.
     * <p>
     * NOTE: this constructor calls the abstract {@link #getWritingDurationMetricName()} before any
     * subclass fields are initialized. Implementations must therefore return a value that does not
     * depend on subclass instance state.
     *
     * @param metricRegistry registry in which all metrics of this instance are created
     */
    public RouterMonitoring(MetricRegistry metricRegistry) {
        this.metricRegistry = metricRegistry;
        this.writingDurationMetric = metricRegistry.gaugeInt64(getWritingDurationMetricName());
    }

    /**
     * Adds the processed, additional and skipped-additional object counts of every partition in the result
     * to the corresponding per-shard counters.
     *
     * @param ruleProcessingResult result whose per-partition statistics are recorded
     */
    void addLogicObjectsMetrics(RuleProcessingResult ruleProcessingResult) {
        ruleProcessingResult.getGroupedStatByPartitions().forEach((partition, stat) -> {
            addProcessedLogicObjects(partition, stat.getProcessedObjectsCnt());
            addAdditionalLogicObjects(partition, stat.getAdditionalObjectsCnt());
            addSkippedAdditionalLogicObjects(partition, stat.getAdditionalSkippedCnt());
        });
    }

    void addProcessedLogicObjects(int partition, long count) {
        addCountMetric(processedLogicObjects, partition, count, getProcessedLogicObjectsMetricName());
    }

    void addAdditionalLogicObjects(int partition, long count) {
        addCountMetric(additionalLogicObjects, partition, count, getAdditionalLogicObjectsMetricName());
    }

    void addSkippedAdditionalLogicObjects(int partition, long count) {
        addCountMetric(skippedAdditionalLogicObjects, partition, count, getSkippedAdditionalLogicObjectsMetricName());
    }

    /**
     * Adds {@code count} to the per-shard counter named {@code metricName}. The counter is registered
     * (labelled with the shard) on first use.
     *
     * @param metricsMap cache of already registered counters, keyed by shard label value
     * @param partition  zero-based partition number
     * @param count      amount to add
     * @param metricName name under which the counter is registered
     */
    private void addCountMetric(Map<String, Counter> metricsMap, int partition, long count, String metricName) {
        metricsMap.computeIfAbsent(toShard(partition),
                        shard -> metricRegistry.counter(metricName, Labels.of(SHARD_LABEL, shard)))
                .add(count);
    }

    /**
     * Updates delays, row counters and the writing-duration gauge.
     *
     * @param partitionToMetrics    per-partition metrics of the processed batch
     * @param startWritingTimestamp start of the write. It is subtracted from
     *                              {@code System.currentTimeMillis() / 1000}, so it is expected in epoch
     *                              seconds.
     */
    void updateMetrics(Map<Integer, PartitionMetrics> partitionToMetrics, long startWritingTimestamp) {
        updateDelays(partitionToMetrics);
        updateRowsCounts(partitionToMetrics);
        addWritingDuration(System.currentTimeMillis() / 1000 - startWritingTimestamp);
    }

    /**
     * Logs the current delay of every partition and stores its max timestamp for the delay gauge.
     */
    void updateDelays(Map<Integer, PartitionMetrics> partitionToMetrics) {
        long now = System.currentTimeMillis() / 1000;
        partitionToMetrics.forEach(
                (k, v) -> {
                    long delay = now - v.maxTimestamp;
                    logger.info("Router delay for partition {}: {} sec., max timestamp: {}", k, delay,
                            v.maxTimestamp);
                    setTimestamp(k, v.maxTimestamp);
                }
        );
    }

    /**
     * Stores {@code timestamp} for the partition's shard, replacing the previous value. The first time a
     * shard is seen, a lazy delay gauge is registered for it. The gauge reads the stored value whenever it
     * is evaluated.
     *
     * @param partition zero-based partition number
     * @param timestamp timestamp compared against {@code System.currentTimeMillis() / 1000}, i.e. epoch seconds
     */
    void setTimestamp(int partition, long timestamp) {
        String shard = toShard(partition);

        // ConcurrentHashMap.put returns null only for the first insertion of the key, so the gauge is
        // registered once per shard.
        Long previousValue = maxTimestamps.put(shard, timestamp);
        if (previousValue == null) {
            metricRegistry.lazyGaugeInt64(
                    getDelayMetricName(),
                    Labels.of(SHARD_LABEL, shard),
                    () -> calculateDelayForShard(shard)
            );
        }
    }

    /**
     * Returns the delay (in seconds) for the given shard.
     * <p>
     * Returns 0 when no timestamp is stored. {@link #maxTimestamps} is protected and may be modified by
     * subclasses, and unboxing a missing value would throw an NPE inside the gauge supplier.
     */
    private long calculateDelayForShard(String shard) {
        Long timestamp = maxTimestamps.get(shard);
        return timestamp == null ? 0L : calculateDelay(timestamp);
    }

    /** Returns {@code now - timestamp}, with "now" taken as {@code System.currentTimeMillis() / 1000}. */
    private long calculateDelay(long timestamp) {
        long now = System.currentTimeMillis() / 1000;
        return now - timestamp;
    }

    /**
     * Logs and adds the row count of every partition to the per-shard row counters.
     */
    void updateRowsCounts(Map<Integer, PartitionMetrics> partitionToMetrics) {
        partitionToMetrics
                .forEach((k, v) -> {
                    logger.info("Router handling rows for partition {}: {}", k, v.rowsCount);
                    addRowsCount(k, v.rowsCount);
                });
    }

    void addRowsCount(int partition, long count) {
        addCountMetric(rowsCount, partition, count, getRowCountMetricName());
    }

    /**
     * Adds {@code count} to the counter of written messages for the given writer config.
     * <p>
     * The counter is labelled with the config's group (under {@value #SHARD_LABEL}) and its topic
     * (under {@value #TOPIC_LABEL}).
     */
    void addWritingMessages(LogbrokerWriterAdditionalConfig logbrokerWriterAdditionalConfig, int count) {
        writingMessages.computeIfAbsent(logbrokerWriterAdditionalConfig,
                        conf -> metricRegistry.counter(getWritingMessagesMetricName(), Labels.of(SHARD_LABEL,
                                conf.getGroup().toString(), TOPIC_LABEL, conf.getTopic())))
                .add(count);
    }

    /** Sets the writing-duration gauge to {@code writingDuration}. */
    void addWritingDuration(long writingDuration) {
        writingDurationMetric.set(writingDuration);
    }

    /** Converts a zero-based partition number into the one-based shard label value. */
    private static String toShard(int partition) {
        return String.valueOf(partition + 1);
    }

    abstract String getDelayMetricName();

    abstract String getProcessedLogicObjectsMetricName();

    abstract String getAdditionalLogicObjectsMetricName();

    abstract String getSkippedAdditionalLogicObjectsMetricName();

    abstract String getRowCountMetricName();

    abstract String getWritingMessagesMetricName();

    /** Name of the writing-duration gauge. Called from the constructor (see the constructor note). */
    abstract String getWritingDurationMetricName();
}
