package ru.yandex.infra.stage.cache;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.metrics.NamespacedGaugeRegistry;
import ru.yandex.inside.yt.kosher.impl.ytree.YTreeProtoUtils;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTreeBuilder;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yt.rpcproxy.ETransactionType;
import ru.yandex.yt.ytclient.proxy.ApiServiceTransactionOptions;
import ru.yandex.yt.ytclient.proxy.ModifyRowsRequest;
import ru.yandex.yt.ytclient.proxy.SelectRowsRequest;
import ru.yandex.yt.ytclient.proxy.YtClient;
import ru.yandex.yt.ytclient.proxy.request.CreateNode;
import ru.yandex.yt.ytclient.proxy.request.ObjectType;
import ru.yandex.yt.ytclient.tables.ColumnValueType;
import ru.yandex.yt.ytclient.tables.TableSchema;
import ru.yandex.yt.ytclient.wire.UnversionedRow;

/**
 * {@link CacheStorage} that keeps cached protobuf objects in a sorted dynamic YT table.
 * <p>
 * Each cached object type gets its own table at {@code <path>/<cachedObjectType name>} with a string key column
 * {@value #COLUMN_NAME_ID} and an {@code any} column {@value #COLUMN_NAME_VALUE} holding the message marshalled
 * to YTree. Every write/remove is executed and committed in its own transaction.
 * <p>
 * Transaction counters are static, so they aggregate over all instances; they are registered (under the
 * {@code "yt"} namespace) only in the gauge registry passed to the first instance that is constructed.
 *
 * @param <TProtoValue> protobuf message type of the cached values
 */
public class YtCacheStorage<TProtoValue extends Message> implements CacheStorage<TProtoValue> {

    public static final String COLUMN_NAME_ID = "id";
    public static final String COLUMN_NAME_VALUE = "value";

    private static final String METRIC_TOTAL_TRANSACTIONS_COUNT = "total_transactions_count";
    private static final String METRIC_FAILED_TRANSACTIONS_COUNT = "failed_transactions_count";
    private static final String METRIC_COMPLETED_TRANSACTIONS_COUNT = "completed_transactions_count";

    // Shared by all instances: gauges are registered once, by whichever instance is constructed first.
    private static final AtomicBoolean metricsInitialized = new AtomicBoolean(false);
    private static final AtomicLong metricTotalTransactionsCount = new AtomicLong();
    private static final AtomicLong metricFailedTransactionsCount = new AtomicLong();
    private static final AtomicLong metricCompletedTransactionsCount = new AtomicLong();

    private static final Logger LOG = LoggerFactory.getLogger(YtCacheStorage.class);

    private final YtClient yt;
    /** Full path of the table: the configured directory plus the cached object type name. */
    private final String path;
    private final Supplier<Message.Builder> builderSupplier;
    /** Cached object type name; used as the table name and as the log prefix. */
    private final String objectType;
    /** Maximum number of rows written in a single transaction by {@link #write(Map)}. */
    private final int batchInsertChunkSize;
    private final Optional<Duration> readRequestTimeout;

    private final ApiServiceTransactionOptions transactionOptions = new ApiServiceTransactionOptions(ETransactionType.TT_MASTER)
            .setSticky(true)
            .setPing(true);
    private final TableSchema tableSchema = new TableSchema.Builder()
            .addKey(COLUMN_NAME_ID, ColumnValueType.STRING)
            .addValue(COLUMN_NAME_VALUE, ColumnValueType.ANY)
            .build();

    /**
     * Creates a storage for one cached object type.
     *
     * @param yt                   YT client used for all requests
     * @param path                 parent directory; the table is {@code path + "/" + cachedObjectType.getName()}
     * @param cachedObjectType     supplies the table name and the protobuf builder used when reading values
     * @param batchInsertChunkSize maximum number of rows per transaction in {@link #write(Map)}; must be positive
     * @param readRequestTimeout   optional timeout applied to the select request in {@link #read()}
     * @param gaugeRegistry        registry for the transaction counters (used only by the first instance created)
     * @throws IllegalArgumentException if {@code batchInsertChunkSize} is not positive
     */
    public YtCacheStorage(YtClient yt,
                          String path,
                          CachedObjectType<?, TProtoValue> cachedObjectType,
                          int batchInsertChunkSize,
                          Optional<Duration> readRequestTimeout,
                          GaugeRegistry gaugeRegistry) {
        if (batchInsertChunkSize <= 0) {
            // A non-positive chunk size would make batched writes fail (division by zero / missing chunks).
            throw new IllegalArgumentException("batchInsertChunkSize must be positive, got " + batchInsertChunkSize);
        }
        this.yt = yt;
        this.path = path + "/" + cachedObjectType.getName();
        this.builderSupplier = cachedObjectType.getBuilderSupplier();
        this.objectType = cachedObjectType.getName();
        this.batchInsertChunkSize = batchInsertChunkSize;
        this.readRequestTimeout = readRequestTimeout;

        if (metricsInitialized.compareAndSet(false, true)) {
            GaugeRegistry registry = new NamespacedGaugeRegistry(gaugeRegistry, "yt");
            registry.add(METRIC_TOTAL_TRANSACTIONS_COUNT, new GolovanableGauge<>(metricTotalTransactionsCount::get, "dmmm"));
            registry.add(METRIC_FAILED_TRANSACTIONS_COUNT, new GolovanableGauge<>(metricFailedTransactionsCount::get, "dmmm"));
            registry.add(METRIC_COMPLETED_TRANSACTIONS_COUNT, new GolovanableGauge<>(metricCompletedTransactionsCount::get, "dmmm"));
        }
    }

    /**
     * Ensures the cache table exists: if the node is missing, creates it as a dynamic table and mounts it.
     *
     * @return a future completing once the table exists (and, when freshly created, is mounted)
     */
    @Override
    public CompletableFuture<?> init() {
        return yt.existsNode(path)
                .thenCompose(exists -> {
                    if (exists) {
                        LOG.info("[{}] Found existed cache table: {}", objectType, path);
                        return CompletableFuture.completedFuture(null);
                    }
                    return createTable();
                });
    }

    /** Creates the dynamic table (recursively, failing if it already exists) and mounts it. */
    private CompletableFuture<Void> createTable() {
        LOG.info("[{}] Creating table: {}", objectType, path);

        List<MapF<String, String>> schema = List.of(
                Cf.map("name", COLUMN_NAME_ID,
                        "type", "string",
                        "required", "true",
                        "sort_order", "ascending"),
                Cf.map("name", COLUMN_NAME_VALUE,
                        "type", "any"));

        MapF<String, YTreeNode> attributes = Cf.map(
                "dynamic", new YTreeBuilder().value(true).build(),
                "optimize_for", new YTreeBuilder().value("scan").build(),
                "schema", new YTreeBuilder().value(schema).build()
        );

        final CreateNode createNodeRequest = new CreateNode(path, ObjectType.Table, attributes)
                .setRecursive(true)
                .setIgnoreExisting(false);

        return yt.createNode(createNodeRequest)
                .thenCompose(tableId -> {
                    LOG.info("[{}] Mounting table with id {}: {}", objectType, tableId, path);
                    return yt.mountTable(path);
                })
                .thenRun(() -> LOG.info("[{}] Mounted table: {}", objectType, path));
    }

    /**
     * Nothing is buffered: every write/remove commits its own transaction, so this returns an
     * already-completed future.
     */
    @Override
    public CompletableFuture<?> flush() {
        return CompletableFuture.completedFuture(null);
    }

    /**
     * Reads the whole table.
     * <p>
     * Duplicate ids would make {@link Collectors#toMap} fail; they are not expected because {@code id} is the
     * table's sort key.
     *
     * @return a future with all cached values keyed by id
     */
    @Override
    public CompletableFuture<Map<String, TProtoValue>> read() {
        SelectRowsRequest request = SelectRowsRequest.of(
                String.format("%s, %s FROM [%s]", COLUMN_NAME_ID, COLUMN_NAME_VALUE, path));
        readRequestTimeout.ifPresent(request::setTimeout);

        return yt.selectRows(request)
                .thenApply(rowset -> rowset.getRows()
                        .stream()
                        .map(UnversionedRow::getValues)
                        .collect(Collectors.toMap(
                                rowValues -> rowValues.get(0).stringValue(),
                                rowValues -> parseValue(rowValues.get(1).toYTree()))));
    }

    /** Unmarshals one {@value #COLUMN_NAME_VALUE} cell into a protobuf message. */
    @SuppressWarnings("unchecked")
    private TProtoValue parseValue(YTreeNode node) {
        Message.Builder builder = builderSupplier.get();
        YTreeProtoUtils.unmarshal(node, builder);
        // The builder comes from the same CachedObjectType<?, TProtoValue>, so the built message is a TProtoValue.
        return (TProtoValue) builder.build();
    }

    /**
     * Upserts all given values in key order, in chunks of at most {@code batchInsertChunkSize} rows, one
     * transaction per chunk, executed sequentially.
     * <p>
     * Best effort: a failed chunk is logged by {@link #modifyRows} and skipped, and the remaining chunks are
     * still written. The returned future therefore completes normally even if some chunks failed.
     *
     * @param values values to store, keyed by id
     * @return a future completing after all chunks have been attempted
     */
    @Override
    public CompletableFuture<?> write(Map<String, TProtoValue> values) {
        if (values.isEmpty()) {
            return CompletableFuture.completedFuture(null);
        }

        // Sorted by key, so each chunk holds a contiguous run of keys; the logged range is the row index in that order.
        List<Map<String, Object>> rows = values.entrySet()
                .stream()
                .sorted(Map.Entry.comparingByKey())
                .map(entry -> toRow(entry.getKey(), entry.getValue()))
                .collect(Collectors.toList());

        CompletableFuture<?> future = CompletableFuture.completedFuture(null);
        int nextStart = 0;
        while (nextStart < rows.size()) {
            final int startIndex = nextStart;
            // Computed as start + min(...) to avoid int overflow for very large chunk sizes.
            final int endIndex = startIndex + Math.min(batchInsertChunkSize, rows.size() - startIndex) - 1;
            final List<Map<String, Object>> chunk = rows.subList(startIndex, endIndex + 1);
            final String chunkName = String.format("keys range [%d; %d]", startIndex, endIndex);
            future = future
                    .thenCompose(ignored -> modifyRows(chunkName, request -> chunk.forEach(request::addUpdate), "update"))
                    .thenRun(() -> LOG.info("[{}] saved {} rows in range [{}; {}]", objectType, chunk.size(), startIndex, endIndex))
                    // Deliberately dropped (already logged in modifyRows) so later chunks are still attempted.
                    .exceptionally(ignored -> null);
            nextStart = endIndex + 1;
        }
        return future;
    }

    /** Upserts a single value in its own transaction; the returned future fails if the transaction fails. */
    @Override
    public CompletableFuture<?> write(String key, TProtoValue value) {
        return modifyRows(key, request -> request.addUpdate(toRow(key, value)), "update");
    }

    /** Deletes the row with the given key in its own transaction; the returned future fails if it fails. */
    @Override
    public CompletableFuture<?> remove(String key) {
        return modifyRows(key, request -> request.addDelete(List.of(key)), "remove");
    }

    /** Builds a table row (id + marshalled value) for an upsert. */
    private Map<String, Object> toRow(String key, TProtoValue value) {
        return Map.of(COLUMN_NAME_ID, key,
                COLUMN_NAME_VALUE, YTreeProtoUtils.marshal(value));
    }

    /** Creates a modify request for this table and lets {@code operation} fill it. */
    private ModifyRowsRequest buildRequest(Consumer<ModifyRowsRequest> operation) {
        ModifyRowsRequest request = new ModifyRowsRequest(path, tableSchema);
        operation.accept(request);
        return request;
    }

    /**
     * Runs {@code operation} against this table in a new transaction and commits it.
     * <p>
     * Any failure after the transaction has started (including one thrown while filling the request) aborts
     * the transaction. Every failure, including a failure to start the transaction, is logged and counted.
     *
     * @param key               identifier used only for logging
     * @param operation         fills the modify request (adds updates/deletes)
     * @param actionDescription verb used in log messages
     * @return a future completing when the transaction is committed, or failing with the cause
     */
    private CompletableFuture<?> modifyRows(String key, Consumer<ModifyRowsRequest> operation, String actionDescription) {
        metricTotalTransactionsCount.incrementAndGet();
        LOG.info("[{}] Trying to {} row with key {}", objectType, actionDescription, key);
        return yt.startTransaction(transactionOptions)
                .thenCompose(transaction -> CompletableFuture.completedFuture(operation)
                        // Building the request inside this chain guarantees the abort below also covers its failures.
                        .thenApply(this::buildRequest)
                        .thenCompose(request -> transaction.modifyRows(request))
                        .thenCompose(ignored -> transaction.commit())
                        .whenComplete((ignored, error) -> {
                            if (error != null) {
                                transaction.abort();
                            }
                        }))
                .whenComplete((ignoredResult, error) -> {
                    if (error != null) {
                        metricFailedTransactionsCount.incrementAndGet();
                        // Throwable as the trailing argument (no placeholder) so SLF4J logs the stack trace.
                        LOG.error("[{}] Failed to {} row with key {}", objectType, actionDescription, key, error);
                    } else {
                        metricCompletedTransactionsCount.incrementAndGet();
                    }
                });
    }
}
