package ru.yandex.solomon.experiments.gordiychuk.recovery.metabase;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.ReadTableSettings;
import com.yandex.ydb.table.values.TupleValue;

import ru.yandex.monlib.metrics.primitives.GaugeInt64;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.experiments.gordiychuk.recovery.Record;
import ru.yandex.solomon.tool.YdbClient;
import ru.yandex.solomon.util.future.RetryCompletableFuture;
import ru.yandex.solomon.util.future.RetryConfig;

import static com.yandex.ydb.table.values.PrimitiveValue.uint32;
import static com.yandex.ydb.table.values.PrimitiveValue.utf8;

/**
 * Reads all rows of a single shard's metrics table ({@code <root>/Solomon/metrics/<shardId>})
 * and passes them to a consumer as {@link Record}s.
 * <p>
 * The table is read in key order. When a read attempt fails it is retried according to
 * {@link #RETRY_CONFIG}, resuming after the last key that was seen, so rows that were already
 * delivered are not requested again.
 *
 * @author Vladimir Gordiychuk
 */
public class MetabaseShardMetricsReader {
    /** Retry policy applied to the table read in {@link #run(Consumer)}. */
    private static final RetryConfig RETRY_CONFIG = RetryConfig.DEFAULT.withDelay(1_000).withMaxDelay(60_000).withNumRetries(1000);

    /** Shard identifier; it is also the last segment of {@link #path}. */
    private final String shardId;
    /** Full path of the table that is read. */
    private final String path;
    private final YdbClient client;
    /** Age filter handed to the row reader; {@code 0} disables filtering. */
    private final long ageBeforeMillis;
    /** Number of rows received so far by this instance (filtered rows included). */
    private final AtomicLong read = new AtomicLong();
    private final GaugeInt64 readMetrics = MetricRegistry.root().gaugeInt64("metabase.read.metrics");

    /**
     * @param root            prefix of the table path
     * @param client          client used to execute the count query and the table read
     * @param shardId         shard whose metrics table is read
     * @param ageBeforeMillis age filter passed to the row reader, {@code 0} to accept every row
     */
    public MetabaseShardMetricsReader(String root, YdbClient client, String shardId, long ageBeforeMillis) {
        this.client = client;
        this.shardId = shardId;
        this.path = root + "/Solomon/metrics/" + shardId;
        this.ageBeforeMillis = ageBeforeMillis;
    }

    /**
     * Counts the rows of the table; the result is only used to report read progress.
     *
     * @return future with the row count, {@code 0} when the query returns no row
     */
    private CompletableFuture<Long> getMetricsCount() {
        String query = String.format("""
                --!syntax_v1
                select cast(count(*) as Uint64) from `%s`;
                """, path);
        return client.fluent().execute(query).thenApply(result -> {
            var resultSet = result.expect("success").getResultSet(0);
            if (!resultSet.next()) {
                return 0L;
            }

            return resultSet.getColumn(0).getUint64();
        });
    }

    /**
     * Computes the read progress in percent.
     * <p>
     * A non-positive {@code total} (empty table at the time of counting) is reported as 100%
     * instead of dividing by zero, which would produce {@code NaN} or {@code Infinity}.
     */
    private static double progressPercent(long readSoFar, long total) {
        return total > 0 ? readSoFar * 100.0 / total : 100.0;
    }

    /**
     * Reads the whole table and feeds the rows that pass the age filter to {@code consumer}.
     *
     * @param consumer receives every accepted row; it is invoked from the table read callback
     * @return future completed when the table was read to the end, or completed exceptionally
     *         when the read still fails after the retries are exhausted
     */
    public CompletableFuture<Void> run(Consumer<Record> consumer) {
        // Starting from a completed future turns a synchronous failure of getMetricsCount()
        // into a failed future instead of an exception thrown from run().
        return CompletableFuture.completedFuture(null)
                .thenCompose(ignore -> getMetricsCount())
                .thenCompose(metrics -> {
                    long total = metrics;
                    // The reader is shared by all attempts: it remembers the last key seen.
                    MetricsReader reader = new MetricsReader(ageBeforeMillis);
                    // Resolve the gauge once instead of on every received chunk.
                    var progressGauge = MetricRegistry.root().gaugeDouble("merge.metabase.progress");
                    return RetryCompletableFuture.runWithRetries(() -> {
                        return client.fluent().executeOnSession(session -> {
                            var from = reader.lastKey();
                            System.out.println("Read table " + path + " from " + from + "...");
                            var settings = ReadTableSettings.newBuilder()
                                    .orderedRead(true)
                                    .timeout(1, TimeUnit.MINUTES);
                            if (from != null) {
                                // Resume right after the last row handled by a previous attempt.
                                settings.fromKeyExclusive(from);
                            }

                            return session.readTable(path, settings.build(), resultSet -> {
                                reader.read(resultSet, consumer);
                                int rows = resultSet.getRowCount();
                                long readSoFar = read.addAndGet(rows);
                                readMetrics.add(rows);

                                double progress = progressPercent(readSoFar, total);
                                System.out.println("Read table " + path + " " + String.format("%.2f%%", progress));
                                progressGauge.set(progress);
                            });
                        }).thenAccept(status -> status.expect("can not read table " + path));
                    }, RETRY_CONFIG);
                }).thenAccept(unit -> {
                    System.out.println("Read table " + path + " done");
                });
    }

    /**
     * Converts result set rows into {@link Record}s and remembers the key ({@code hash},
     * {@code labels}) of the last row it has seen, so that an interrupted read can be resumed.
     */
    private static final class MetricsReader {
        // NOTE(review): the enclosing class passes a value named ageBeforeMillis here, while the
        // filter is compared with the createdSeconds column - confirm the units with the caller.
        private final long ageSecondsFilter;

        // Column indexes, resolved lazily from the first non-empty result set.
        private int hashIdx = -1;
        private int labelsIdx = -1;
        private int shardIdIdx = -1;
        private int localIdIdx = -1;
        private int createdSecondsIdx = -1;
        private int flagsIdx = -1;
        // Key of the last row seen; lastLabels == null means no row was read yet.
        private int lastHash;
        private String lastLabels;

        /**
         * @param ageSecondsFilter rows whose {@code createdSeconds} is greater than this value
         *                         are skipped; {@code 0} accepts every row
         */
        public MetricsReader(long ageSecondsFilter) {
            this.ageSecondsFilter = ageSecondsFilter;
        }

        /**
         * Consumes all rows of {@code rs}, passing the rows that satisfy the age filter to
         * {@code fn}. The last seen key is updated for every row, including filtered ones.
         */
        void read(ResultSetReader rs, Consumer<Record> fn) {
            if (rs.getRowCount() == 0) {
                return;
            }

            if (labelsIdx == -1) {
                hashIdx = rs.getColumnIndex("hash");
                labelsIdx = rs.getColumnIndex("labels");
                shardIdIdx = rs.getColumnIndex("shardId");
                localIdIdx = rs.getColumnIndex("localId");
                createdSecondsIdx = rs.getColumnIndex("createdSeconds");
                flagsIdx = rs.getColumnIndex("flags");
            }

            while (rs.next()) {
                // Track the key before anything else so lastKey() always reflects this row.
                lastHash = (int) rs.getColumn(hashIdx).getUint32();
                lastLabels = rs.getColumn(labelsIdx).getUtf8();

                boolean accepted = ageSecondsFilter == 0
                        || ageSecondsFilter >= rs.getColumn(createdSecondsIdx).getUint64();
                if (!accepted) {
                    // Skip building a Record for rows that would be dropped anyway.
                    continue;
                }

                Record record = new Record();
                record.labels = lastLabels;
                record.shardId = (int) rs.getColumn(shardIdIdx).getUint32();
                record.localId = rs.getColumn(localIdIdx).getUint64();
                record.flags = (int) rs.getColumn(flagsIdx).getUint32();
                fn.accept(record);
            }
        }

        /**
         * @return key of the last row seen as a ({@code hash}, {@code labels}) tuple,
         *         or {@code null} when no row has been read yet
         */
        public TupleValue lastKey() {
            if (lastLabels == null) {
                return null;
            }

            return TupleValue.of(
                    uint32(lastHash).makeOptional(),
                    utf8(lastLabels).makeOptional());
        }
    }
}
