package ru.yandex.solomon.coremon.meta.db.ydb;

import java.time.Duration;
import java.util.Collection;
import java.util.List;
import java.util.OptionalLong;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import com.google.common.collect.Lists;
import com.google.common.primitives.UnsignedInteger;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.DataQueryResult;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.ExecuteDataQuerySettings;
import com.yandex.ydb.table.settings.ReadTableSettings;
import com.yandex.ydb.table.transaction.TxControl;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.TupleValue;
import com.yandex.ydb.table.values.Value;

import ru.yandex.monlib.metrics.labels.LabelAllocator;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.labels.LabelsBuilder;
import ru.yandex.solomon.coremon.meta.CoremonMetricArray;
import ru.yandex.solomon.coremon.meta.db.MetricsDao;
import ru.yandex.solomon.coremon.meta.db.MetricsDaoStats;
import ru.yandex.solomon.util.actors.AsyncActorBody;
import ru.yandex.solomon.util.actors.AsyncActorRunner;

import static com.yandex.ydb.table.values.PrimitiveValue.uint32;
import static com.yandex.ydb.table.values.PrimitiveValue.utf8;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static ru.yandex.solomon.ydb.YdbResultSets.uint64;

/**
 * @author Sergey Polovko
 */
final class YdbMetricsHugeTableDao implements MetricsDao {

    // findMetrics() uses a plain SELECT only when the known metric count is below this value;
    // otherwise (or when the count is unknown) it reads via ReadTable.
    private static final long SELECT_USAGE_THRESHOLD = 1000;
    // Batch size used by replaceMetrics() when partitioning rows into separate replace queries.
    private static final int BATCH_MAX_SIZE = 1000;
    // Passed to AsyncActorRunner in replaceMetrics(); presumably the max number of
    // batches executed concurrently -- TODO confirm against AsyncActorRunner.
    private static final int ASYNC_IN_FLIGHT = 4;

    // Retry context built in the constructor (10 retries, 30 s session supply timeout).
    private final SessionRetryContext retryCtx;
    // Shard whose rows this DAO reads and writes; bound as $shardId in queries.
    private final int shardId;
    // Full table path: constructor's path + "/" + HugeTableSettings.makeName(shardId).
    private final String tablePath;
    // Counters and timers updated by read, write and delete operations.
    private final MetricsDaoStats stats;
    // Handed to Labels.builder() when parsing the "labels" column.
    private final LabelAllocator labelAllocator;

    // Query texts below are formatted once in the constructor against tablePath.
    // Selects every row of this shard.
    private final String selectQuery;
    // Counts rows of this shard as Uint64.
    private final String countQuery;
    // Upserts the rows from $rows that are new, or whose flags differ while
    // spLocalId/spShardId match, and returns those rows.
    private final String replaceQuery;
    // Deletes the rows identified by the $keys list.
    private final String deleteQuery;
    // Deletes up to 5000 rows of this shard per execution and returns how many were selected.
    private final String deleteBatch;

    YdbMetricsHugeTableDao(
        TableClient tableClient,
        int shardId,
        String path,
        MetricsDaoStats stats,
        LabelAllocator labelAllocator)
    {
        this.shardId = shardId;
        this.tablePath = path + "/" + HugeTableSettings.makeName(shardId);

        this.retryCtx = SessionRetryContext.create(tableClient)
            .maxRetries(10)
            .sessionSupplyTimeout(Duration.ofSeconds(30))
            .build();
        this.stats = stats;
        this.labelAllocator = labelAllocator;

        this.selectQuery = String.format("""
                --!syntax_v1
                DECLARE $shardId as Uint32;
                SELECT * FROM `%s` WHERE shardId = $shardId;
                """, tablePath);

        this.countQuery = String.format("""
                --!syntax_v1
                DECLARE $shardId as Uint32;
                SELECT cast(count(*) as Uint64) FROM `%s` WHERE shardId = $shardId;
                """, tablePath);

        this.replaceQuery = String.format("""
                --!syntax_v1
                DECLARE $rows AS %s;

                $new_labels = SELECT n.*
                FROM AS_TABLE($rows) AS n
                    LEFT JOIN `%s` e ON n.shardId = e.shardId AND n.hash = e.hash AND n.labels = e.labels
                WHERE e.shardId IS NULL OR (n.flags <> e.flags AND n.spLocalId = e.spLocalId AND n.spShardId = e.spShardId);
                UPSERT INTO `%s` SELECT * FROM $new_labels;

                SELECT * FROM $new_labels;
                """, MetricConverter.V2.METRICS_LIST_TYPE, tablePath, tablePath);

        this.deleteQuery = String.format("""
                --!syntax_v1
                DECLARE $keys AS %s;
                DELETE FROM `%s` ON SELECT * FROM AS_TABLE($keys);
                """, MetricConverter.V2.KEYS_LIST_TYPE, tablePath);

        this.deleteBatch = String.format("""
                --!syntax_v1
                DECLARE $shardId as Uint32;

                $to_delete = (SELECT shardId, hash, labels FROM `%s` WHERE shardId = $shardId LIMIT 5000);

                SELECT cast(count(*) as Uint64) FROM $to_delete;

                DELETE FROM `%s` ON
                SELECT * FROM $to_delete;
                """, tablePath, tablePath);
    }

    /**
     * @return full path of the table this DAO works with
     */
    String getTablePath() {
        return this.tablePath;
    }

    /**
     * @return id of the shard this DAO was created for
     */
    public int getShardId() {
        return this.shardId;
    }

    /**
     * Creates the metrics table at {@code tablePath}.
     *
     * <p>All columns are nullable; the primary key is {@code (shardId, hash, labels)}.
     *
     * @return future that completes when the table is created; the returned status is checked
     *         with {@code expect()}, so a non-success status fails the future
     */
    @Override
    public CompletableFuture<Void> createSchema() {
        final var schema = TableDescription.newBuilder()
            .addNullableColumn("shardId", PrimitiveType.uint32())
            .addNullableColumn("hash", PrimitiveType.uint32())
            .addNullableColumn("labels", PrimitiveType.utf8())
            .addNullableColumn("spShardId", PrimitiveType.uint32())
            .addNullableColumn("spLocalId", PrimitiveType.uint64())
            .addNullableColumn("createdAt", PrimitiveType.uint64())
            .addNullableColumn("flags", PrimitiveType.uint32())
            .setPrimaryKeys("shardId", "hash", "labels")
            .build();

        var created = retryCtx.supplyStatus(session -> session.createTable(tablePath, schema));
        return created.thenAccept(status -> {
            status.expect("cannot create table " + tablePath);
        });
    }

    /**
     * Currently a no-op: nothing is dropped and an already completed future is returned.
     */
    @Override
    public CompletableFuture<Void> dropSchema() {
        // TODO: implement background deletion task
        final Void nothing = null;
        return CompletableFuture.completedFuture(nothing);
    }

    /**
     * Counts the rows stored for this shard.
     *
     * @return future with the row count, or {@code 0} when the count query returned no row
     */
    @Override
    public CompletableFuture<Long> getMetricCount() {
        final Params queryParams = Params.of("$shardId", uint32(shardId));
        return execute(countQuery, queryParams).thenApply(queryResult -> {
            DataQueryResult data = queryResult.expect(
                "cannot get rows count by shardId=" + Integer.toUnsignedLong(shardId));
            ResultSetReader rows = data.getResultSet(0);
            if (rows.next()) {
                return uint64(rows, 0);
            }
            return 0L;
        });
    }

    /**
     * Loads all metrics of this shard and feeds them to {@code consumer}.
     *
     * <p>When the metric count is known and below {@code SELECT_USAGE_THRESHOLD}, a plain SELECT
     * is tried first; if its result turns out to be truncated, or in any other case, the rows are
     * read with ReadTable.
     *
     * @param consumer    receives the loaded metrics
     * @param metricCount expected number of metrics, if known
     * @return future with the total number of rows read
     */
    @Override
    public CompletableFuture<Long> findMetrics(Consumer<CoremonMetricArray> consumer, OptionalLong metricCount) {
        final boolean fewMetrics = metricCount.isPresent() && metricCount.getAsLong() < SELECT_USAGE_THRESHOLD;
        if (!fewMetrics) {
            return findWithReadTable(consumer);
        }

        return findWithSelect(consumer).thenCompose(selectedRows -> {
            if (selectedRows.isEmpty()) {
                // selected result was truncated, so fallback to readtable
                return findWithReadTable(consumer);
            }
            return CompletableFuture.completedFuture(selectedRows.getAsLong());
        });
    }

    /**
     * Reads the shard's metrics with a single SELECT query.
     *
     * @param consumer receives the metrics; it is not invoked when the result set is truncated
     * @return future with the number of rows read, or an empty optional when the result set was
     *         truncated and the caller has to fall back to ReadTable
     */
    private CompletableFuture<OptionalLong> findWithSelect(Consumer<CoremonMetricArray> consumer) {
        final Params queryParams = Params.of("$shardId", uint32(shardId));
        CompletableFuture<OptionalLong> selected = execute(selectQuery, queryParams).thenApply(queryResult -> {
            DataQueryResult data = queryResult.expect(
                "cannot select metrics by shardId=" + Integer.toUnsignedLong(shardId));
            ResultSetReader rows = data.getResultSet(0);

            // A truncated result is incomplete: report "unknown" so the caller falls back to readtable.
            if (rows.isTruncated()) {
                return OptionalLong.empty();
            }

            MetricsReader rowsReader = new MetricsReader(labelAllocator, stats, consumer);
            rowsReader.accept(rows);
            return OptionalLong.of(rowsReader.getTotalRows());
        });

        stats.selects.forFuture(selected);
        return selected;
    }

    /**
     * Reads the shard's metrics with a ReadTable request limited to this shard's key range.
     *
     * <p>The range starts at {@code (shardId, empty, empty)} and ends at
     * {@code (shardId, max uint32, LabelListSortedSerialize.SENTINEL)}, both inclusive. Rows are
     * read unordered, with a 30 minute timeout.
     *
     * @param consumer receives the metrics as they are read
     * @return future with the total number of rows read
     */
    private CompletableFuture<Long> findWithReadTable(Consumer<CoremonMetricArray> consumer) {
        final var fromKey = TupleValue.of(
            uint32(shardId).makeOptional(),
            PrimitiveType.uint32().makeOptional().emptyValue(),
            PrimitiveType.utf8().makeOptional().emptyValue());
        final var toKey = TupleValue.of(
            uint32(shardId).makeOptional(),
            uint32(UnsignedInteger.MAX_VALUE.intValue()).makeOptional(),
            utf8(LabelListSortedSerialize.SENTINEL).makeOptional());

        final ReadTableSettings settings = ReadTableSettings.newBuilder()
            .timeout(30, TimeUnit.MINUTES)
            .fromKeyInclusive(fromKey)
            .toKeyInclusive(toKey)
            .orderedRead(false)
            .build();

        final MetricsReader rowsReader = new MetricsReader(labelAllocator, stats, consumer);
        var readFuture = retryCtx
            .supplyStatus(session -> session.readTable(tablePath, settings, rowsReader)
                // A failed read is reported as INTERNAL_ERROR so that it is not retried at this
                // level: part of the rows may already have been handed to the consumer, so the
                // caller has to retry the whole read with a cleared collection of previously
                // read metrics.
                .thenApply(status -> status.isSuccess()
                    ? status
                    : Status.of(StatusCode.INTERNAL_ERROR, status.getIssues())))
            .thenApply(status -> {
                status.expect("cannot read table " + tablePath + " for shard " + Integer.toUnsignedLong(shardId));
                return rowsReader.getTotalRows();
            });

        stats.readTables.forFuture(readFuture);
        return readFuture;
    }

    /**
     * Writes the given metrics into this shard's table and returns the rows reported back by the
     * replace query (the rows it upserted).
     *
     * <p>An input of up to {@code BATCH_MAX_SIZE} rows is sent as a single query. A larger input is
     * partitioned into batches of at most {@code BATCH_MAX_SIZE} rows which are executed through an
     * {@code AsyncActorRunner} created with {@code ASYNC_IN_FLIGHT}; the per-batch results are then
     * merged into one array. The per-batch arrays are closed here once the run is finished; the
     * merged array is returned unclosed.
     *
     * @param metrics metrics to write
     * @return future with the rows returned by the replace query; it fails if any batch fails
     */
    @Override
    public CompletableFuture<CoremonMetricArray> replaceMetrics(CoremonMetricArray metrics) {
        List<Value> values = MetricConverter.metricsToValues(shardId, metrics);
        // A list of exactly BATCH_MAX_SIZE rows is a single batch, so it does not need the actor
        // machinery nor the extra copy into a merged array.
        if (values.size() <= BATCH_MAX_SIZE) {
            return doReplaceImpl(values);
        }

        AtomicInteger batchIdx = new AtomicInteger(0);
        List<List<Value>> batches = Lists.partition(values, BATCH_MAX_SIZE);
        var batchResults = new ConcurrentLinkedQueue<CoremonMetricArray>();
        AsyncActorBody body = () -> {
            int idx = batchIdx.getAndIncrement();
            if (idx >= batches.size()) {
                return CompletableFuture.completedFuture(AsyncActorBody.DONE_MARKER);
            }
            return doReplaceImpl(batches.get(idx))
                    .thenApply(m -> batchResults.add(m));
        };

        AsyncActorRunner actorRunner = new AsyncActorRunner(body, ForkJoinPool.commonPool(), ASYNC_IN_FLIGHT);
        return actorRunner.start()
                .thenApply(aVoid -> {
                    var merged = new CoremonMetricArray(values.size());
                    try {
                        for (CoremonMetricArray part : batchResults) {
                            merged.addAll(part);
                        }
                        return merged;
                    } catch (RuntimeException | Error e) {
                        // Nobody else has seen the merged array yet, so release it before failing.
                        merged.closeSilent();
                        throw e;
                    }
                })
                .whenComplete((ignoredResult, ignoredError) -> {
                    // Per-batch arrays are not needed any more, on success as well as on failure.
                    // NOTE(review): if AsyncActorRunner can complete its future while other batches
                    // are still in flight, arrays added after this point are not closed -- confirm.
                    for (CoremonMetricArray part : batchResults) {
                        part.closeSilent();
                    }
                    // No rethrow here: whenComplete() already propagates the upstream failure, and
                    // throwing from the action would only attach a wrapper of that same exception
                    // to it as a suppressed exception.
                });
    }

    /**
     * Executes the replace query for one batch of rows and collects the rows it returns.
     *
     * <p>The write time is recorded as soon as the query result arrives, before the result is
     * checked with {@code expect()}. If reading the returned rows fails, the array allocated for
     * them is closed before the failure is propagated.
     *
     * @param values rows to send as the {@code $rows} parameter
     * @return future with a new array holding the rows from the first result set of the query;
     *         any exception thrown while preparing the query is returned as a failed future
     */
    private CompletableFuture<CoremonMetricArray> doReplaceImpl(List<Value> values) {
        long startNanos = System.nanoTime();
        try {
            int size = values.size();
            stats.insertsCount.add(size);
            Params params = Params.of("$rows", MetricConverter.V2.METRICS_LIST_TYPE.newValue(values));
            return execute(replaceQuery, params)
                    .thenApply(r -> {
                        stats.writeTimeMillis.record(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos), size);
                        var dataResult = r.expect("cannot replace metrics in shard " + Integer.toUnsignedLong(shardId));
                        ResultSetReader resultSet = dataResult.getResultSet(0);

                        var newMetrics = new CoremonMetricArray(size);
                        try {
                            new MetricsReader(labelAllocator, stats, (metrics) -> {
                                newMetrics.addAll(metrics);
                            }).accept(resultSet);
                            return newMetrics;
                        } catch (RuntimeException | Error e) {
                            // The array never reaches the caller on this path, so release it here.
                            newMetrics.closeSilent();
                            throw e;
                        }
                    });
        } catch (Throwable t) {
            return CompletableFuture.failedFuture(t);
        }
    }

    /**
     * Deletes the rows identified by the given label sets from this shard.
     *
     * <p>The delete time is recorded only after the query result has passed {@code expect()}.
     *
     * @param keys labels of the metrics to delete
     * @return future that completes when the delete query succeeded; any exception thrown while
     *         preparing the query is returned as a failed future
     */
    @Override
    public CompletableFuture<Void> deleteMetrics(Collection<Labels> keys) {
        final long startedAtNanos = System.nanoTime();
        try {
            final int keyCount = keys.size();
            stats.deletesCount.add(keyCount);

            final Params queryParams = Params.of("$keys", MetricConverter.keysToList(shardId, keys));
            return execute(deleteQuery, queryParams).thenAccept(queryResult -> {
                queryResult.expect("cannot delete metrics from shard " + Integer.toUnsignedLong(shardId));
                long elapsedNanos = System.nanoTime() - startedAtNanos;
                stats.deleteTimeMillis.record(TimeUnit.NANOSECONDS.toMillis(elapsedNanos), keyCount);
            });
        } catch (Throwable error) {
            return failedFuture(error);
        }
    }

    /**
     * Runs one iteration of the batched delete query for this shard.
     *
     * @return future with the count read from the first result set of the query (the number of
     *         rows selected for deletion), or {@code 0} when that result set has no row
     */
    @Override
    public CompletableFuture<Long> deleteMetricsBatch() {
        final Params queryParams = Params.of("$shardId", uint32(shardId));
        return execute(deleteBatch, queryParams).thenApply(queryResult -> {
            DataQueryResult data = queryResult.expect(
                "cannot get delete metrics by shardId=" + Integer.toUnsignedLong(shardId));
            ResultSetReader rows = data.getResultSet(0);
            if (rows.next()) {
                return rows.getColumn(0).getUint64();
            }
            return 0L;
        });
    }

    /**
     * Executes a data query through the retry context in a serializable read-write transaction,
     * asking the server to keep the query in its query cache.
     *
     * @param query  query text
     * @param params query parameters
     * @return future with the query result; an exception thrown while submitting the query is
     *         returned as a failed future
     */
    private CompletableFuture<Result<DataQueryResult>> execute(String query, Params params) {
        try {
            return retryCtx.supplyResult(session -> session.executeDataQuery(
                query,
                TxControl.serializableRw(),
                params,
                new ExecuteDataQuerySettings().keepInQueryCache()));
        } catch (Throwable error) {
            return CompletableFuture.failedFuture(error);
        }
    }

    /**
     * METRICS READER
     *
     * <p>Converts the rows of each non-empty result set into a {@link CoremonMetricArray} and
     * passes it to the given consumer. The array is closed as soon as the consumer returns.
     * Also keeps the total number of rows converted so far, and records read statistics after
     * every non-empty result set.
     */
    private static final class MetricsReader implements Consumer<ResultSetReader> {
        private final LabelAllocator labelAllocator;
        private final MetricsDaoStats stats;
        private final Consumer<CoremonMetricArray> fn;

        // Column positions, resolved from the first non-empty result set.
        private int labelsIdx = -1;
        private int shardIdIdx = -1;
        private int localIdIdx = -1;
        private int createdAtIdx = -1;
        private int flagsIdx = -1;

        private long totalRows = 0;
        // Moment of creation, later the moment the previous non-empty result set was finished.
        private long prevNanos = System.nanoTime();

        MetricsReader(LabelAllocator labelAllocator, MetricsDaoStats stats, Consumer<CoremonMetricArray> fn) {
            this.fn = fn;
            this.stats = stats;
            this.labelAllocator = labelAllocator;
        }

        @Override
        public void accept(ResultSetReader resultSet) {
            final int rows = resultSet.getRowCount();
            if (rows == 0) {
                return;
            }

            if (labelsIdx == -1) {
                resolveColumnIndexes(resultSet);
            }

            try (CoremonMetricArray batch = new CoremonMetricArray(rows)) {
                readRowsInto(batch, resultSet);
                totalRows += batch.size();
                fn.accept(batch);
            } finally {
                // Recorded even when conversion or the consumer fails.
                final long now = System.nanoTime();
                stats.readTimeMillis.record(TimeUnit.NANOSECONDS.toMillis(now - prevNanos), rows);
                stats.readsCount.add(rows);
                prevNanos = now;
            }
        }

        /** Looks up the positions of the columns read by this reader. */
        private void resolveColumnIndexes(ResultSetReader resultSet) {
            labelsIdx = resultSet.getColumnIndex("labels");
            shardIdIdx = resultSet.getColumnIndex("spShardId");
            localIdIdx = resultSet.getColumnIndex("spLocalId");
            createdAtIdx = resultSet.getColumnIndex("createdAt");
            flagsIdx = resultSet.getColumnIndex("flags");
        }

        /** Appends every remaining row of the result set to {@code target}. */
        private void readRowsInto(CoremonMetricArray target, ResultSetReader resultSet) {
            final LabelsBuilder builder = Labels.builder(Labels.MAX_LABELS_COUNT, labelAllocator);
            while (resultSet.next()) {
                builder.clear();
                final String serializedLabels = resultSet.getColumn(labelsIdx).getUtf8();
                final int spShardId = (int) resultSet.getColumn(shardIdIdx).getUint32();
                final long spLocalId = resultSet.getColumn(localIdIdx).getUint64();
                final var labels = LabelListSortedSerialize.parse(serializedLabels, builder);
                final int createdAt = (int) resultSet.getColumn(createdAtIdx).getUint64();
                final var type = Flags.readMetricType((int) resultSet.getColumn(flagsIdx).getUint32());
                target.add(spShardId, spLocalId, labels, createdAt, type);
            }
        }

        long getTotalRows() {
            return totalRows;
        }
    }
}
