package ru.yandex.direct.useractionlog.writer;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.time.Duration;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import javax.annotation.ParametersAreNonnullByDefault;

import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.config.DirectConfig;
import ru.yandex.direct.db.config.DbConfigFactory;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapper;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapperProvider;
import ru.yandex.direct.dbutil.wrapper.SimpleDb;
import ru.yandex.direct.env.Environment;
import ru.yandex.direct.graphite.GraphiteMetricsBuffer;
import ru.yandex.direct.metric.collector.MetricProvider;
import ru.yandex.direct.useractionlog.LatestRecordTimeFetcher;
import ru.yandex.direct.useractionlog.TableNames;
import ru.yandex.direct.useractionlog.db.ReadActionLogTable;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.GaugeDouble;
import ru.yandex.monlib.metrics.primitives.GaugeInt64;

import static ru.yandex.direct.solomon.SolomonUtils.SOLOMON_REGISTRY;
import static ru.yandex.direct.useractionlog.db.DbConfigUtil.getDbReplicas;
import static ru.yandex.direct.useractionlog.db.DbConfigUtil.getShardReplicas;
import static ru.yandex.direct.useractionlog.db.DbConfigUtil.getWorkingReplica;

@ParametersAreNonnullByDefault
public class WriterMetricProvider implements MetricProvider {
    private static final Logger logger = LoggerFactory.getLogger(WriterMetricProvider.class);
    private static final Duration MAX_DELAY = Duration.ofDays(10);
    private static final Duration OBJECT_COUNT_WINDOW = Duration.ofSeconds(60);
    private static final double GB = Math.pow(2, 30);

    private final DatabaseWrapperProvider wrapperProvider;
    private final Collection<String> allReplicas;
    private final LatestRecordTimeFetcher timeFetcher;
    private final ReadActionLogTable readActionLogTable;
    private final Set<String> shardNames;
    private final Map<String, Pair<LocalDateTime, LocalDateTime>> sourceLastChanged = new HashMap<>();
    private final Duration noValueCriticalDuration;

    public WriterMetricProvider(DirectConfig directConfig, DatabaseWrapperProvider wrapperProvider,
                                DbConfigFactory configFactory,
                                Collection<String> shardNames) {
        // в первом шарде лежит таблица со стейтом
        Collection<String> replicas = getShardReplicas(
                configFactory, SimpleDb.PPCHOUSE_PPC.toString() + ":shards:1");
        Function<String, DatabaseWrapper> fn = ignored -> getWorkingReplica(wrapperProvider, replicas);
        this.wrapperProvider = wrapperProvider;
        this.allReplicas = getDbReplicas(configFactory, SimpleDb.PPCHOUSE_PPC.toString());
        this.timeFetcher = new LatestRecordTimeFetcher(fn);
        this.readActionLogTable = new ReadActionLogTable(fn, TableNames.READ_USER_ACTION_LOG_TABLE);
        this.shardNames = new HashSet<>(shardNames);
        this.noValueCriticalDuration = directConfig.getDuration("alw.no_value_critical_duration");
    }

    @Override
    public void provideMetrics(GraphiteMetricsBuffer buf) {
        // buf не используется в силу отказа от графита
        LocalDateTime now = ZonedDateTime.now(ZoneOffset.UTC).toLocalDateTime();
        addDelayMetrics(now);
        addObjectMetrics(now);
        addTableSizeMetrics();
    }

    private void addDelayMetrics(LocalDateTime now) {
        LocalDate minRecordDate = now.minus(MAX_DELAY).toLocalDate();
        Map<String, LocalDateTime> latestRecordBySource = timeFetcher.getLatestRecordTime(minRecordDate);

        for (Map.Entry<String, LocalDateTime> entry : latestRecordBySource.entrySet()) {
            if (!shardNames.contains(entry.getKey())) {
                continue;
            }
            long delay;
            if (entry.getValue().equals(LocalDateTime.MIN)) {
                // На самом деле записей нет начиная с начала minRecordDate, но если
                // возвращать это время, график будет прыгать при смене minRecordDate
                delay = MAX_DELAY.getSeconds();
            } else {

                // NB: Грязный хак, который прибьёт jvm, если в какой-то шард не будет ничего писаться
                // в течение какого-то времени. Сделан как борьба с дедлоками от keepalive читалки
                // бинлогов. Следует выпилить как только проблема будет полечена.
                // Не будет работать, если в шарде не было записей вот уже MAX_DELAY времени.
                if (Environment.getCached().isProductionOrPrestable()) {
                    Pair<LocalDateTime, LocalDateTime> lastChanged = sourceLastChanged.get(entry.getKey());
                    if (lastChanged == null || !lastChanged.getRight().equals(entry.getValue())) {
                        sourceLastChanged.put(entry.getKey(), Pair.of(now, entry.getValue()));
                    } else {
                        if (Duration.between(lastChanged.getLeft(), now).compareTo(noValueCriticalDuration)
                                > 0) {
                            logger.error(
                                    "{} did not have any objects written for more than {}."
                                            + " Assuming deadlock, shutting down jvm.",
                                    entry.getKey(), noValueCriticalDuration);
                            System.exit(1);
                        }
                    }
                }
                delay = Math.max(ChronoUnit.SECONDS.between(entry.getValue(), now), 0);
            }

            GaugeInt64 delaySensor = SOLOMON_REGISTRY.gaugeInt64(
                    "writer_metric_provider_delay_seconds",
                    Labels.of("source", entry.getKey())
            );
            delaySensor.set(delay);
        }
    }

    private void addObjectMetrics(LocalDateTime now) {
        for (ReadActionLogTable.TypeCount typeCount : readActionLogTable.getCountByTypeBetween(
                now.minus(OBJECT_COUNT_WINDOW), now)) {

            GaugeInt64 objectsSensor = SOLOMON_REGISTRY.gaugeInt64(
                    "writer_metric_provider_objects",
                    Labels.of("type", typeCount.getType())
            );
            objectsSensor.set(typeCount.getCount());
        }
    }

    private void addTableSizeMetrics() {
        List<String> tables = new ArrayList<>();
        tables.add(TableNames.DICT_TABLE);
        tables.add(TableNames.WRITE_USER_ACTION_LOG_TABLE);
        for (String replica : allReplicas) {
            DatabaseWrapper wrapper = wrapperProvider.get(replica);
            if (!wrapper.isAlive()) {
                continue;
            }
            Collection<Pair<String, Long>> tableSizes = getTableSizes(wrapper, tables);
            Labels commonLabels = Labels.of("db_name", wrapper.getDbname());
            for (Pair<String, Long> tableSize : tableSizes) {

                GaugeDouble sizeSensor = SOLOMON_REGISTRY.gaugeDouble(
                        "writer_metric_provider_table_size_gigabytes",
                        commonLabels.add("table", tableSize.getLeft())
                );
                sizeSensor.set(tableSize.getRight() / GB);
            }
        }
    }

    private Collection<Pair<String, Long>> getTableSizes(DatabaseWrapper wrapper, Collection<String> tables) {
        String sql = String.format(
                "SELECT table, sum(bytes) FROM system.parts WHERE table IN (%s) AND active GROUP BY table",
                String.join(", ", Collections.nCopies(tables.size(), "?")));
        Collection<Pair<String, Long>> result = new ArrayList<>();
        wrapper.getDslContext().connection(connection -> {
            try (PreparedStatement statement = connection.prepareStatement(sql)) {
                int index = 1;
                for (String table : tables) {
                    statement.setString(index++, table);
                }
                ResultSet resultSet = statement.executeQuery();
                while (resultSet.next()) {
                    result.add(Pair.of(resultSet.getString(1), resultSet.getLong(2)));
                }
            }
        });
        return result;
    }
}
