package ru.yandex.solomon.coremon.balancer.db.ydb;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.Lists;
import com.yandex.ydb.table.SchemeClient;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.BulkUpsertSettings;
import com.yandex.ydb.table.values.ListValue;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.StructType;
import com.yandex.ydb.table.values.Value;

import ru.yandex.solomon.coremon.balancer.db.BalancerShard;
import ru.yandex.solomon.coremon.balancer.db.BalancerShardsDao;
import ru.yandex.solomon.util.actors.AsyncActorBody;
import ru.yandex.solomon.util.actors.AsyncActorRunner;
import ru.yandex.solomon.ydb.YdbTable;

import static com.yandex.ydb.table.values.PrimitiveValue.timestamp;
import static com.yandex.ydb.table.values.PrimitiveValue.utf8;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static ru.yandex.misc.concurrent.CompletableFutures.safeCall;

/**
 * {@link BalancerShardsDao} stored in the YDB table {@code <root>/BalancerShards}.
 *
 * <p>Single-row writes run the queries built by {@link YdbBalancerShardsQuery};
 * {@link #bulkUpsert(List)} goes through YDB BulkUpsert and splits large inputs into
 * batches of at most {@code BATCH_MAX_SIZE} rows.
 *
 * @author Stanislav Kashirin
 */
@ParametersAreNonnullByDefault
public class YdbBalancerShardsDao implements BalancerShardsDao {

    /** Largest number of rows sent in a single BulkUpsert request. */
    private static final int BATCH_MAX_SIZE = 1000;
    /** In-flight limit handed to {@link AsyncActorRunner} when several batches are needed. */
    private static final int ASYNC_IN_FLIGHT = 10;
    /** Timeout set on every BulkUpsert request. */
    private static final Duration BULK_UPSERT_TIMEOUT = Duration.ofMinutes(1);

    private final String root;
    private final String tablePath;
    private final Table table;
    private final SchemeClient schemeClient;
    private final YdbBalancerShardsQuery queries;

    /**
     * @param root         YDB directory that holds the {@code BalancerShards} table
     * @param tableClient  client handed to the underlying {@link YdbTable}
     * @param schemeClient client used to create directories and describe the table path
     */
    public YdbBalancerShardsDao(String root, TableClient tableClient, SchemeClient schemeClient) {
        String path = root + "/BalancerShards";
        this.root = root;
        this.tablePath = path;
        this.table = new Table(tableClient, path);
        this.schemeClient = schemeClient;
        this.queries = new YdbBalancerShardsQuery(path);
    }

    /**
     * Creates the directory chain for {@code root}. It then creates the table, unless describing
     * the table path already succeeds.
     */
    @Override
    public CompletableFuture<Void> createSchemaForTests() {
        return schemeClient.makeDirectories(root)
            .thenAccept(status -> status.expect("parent directories"))
            .thenCompose(unused -> schemeClient.describePath(tablePath))
            .thenCompose(described -> {
                if (described.isSuccess()) {
                    // The path can be described, so nothing needs to be created.
                    return completedFuture(null);
                }
                return table.create();
            });
    }

    /** Drops the {@code BalancerShards} table. */
    @Override
    public CompletableFuture<Void> dropSchemaForTests() {
        return table.drop();
    }

    /** Writes a single shard with the upsert query. */
    @Override
    public CompletableFuture<Void> upsert(BalancerShard shard) {
        return safeCall(() -> {
            Params params = table.toParams(shard);
            return table.queryVoid(queries.upsert, params);
        });
    }

    /** Removes the shard with the given id with the delete query. */
    @Override
    public CompletableFuture<Void> delete(String shardId) {
        return safeCall(() -> table.queryVoid(queries.delete, Params.of("$id", utf8(shardId))));
    }

    /** Reads every row of the table. */
    @Override
    public CompletableFuture<List<BalancerShard>> findAll() {
        return table.queryAll();
    }

    /**
     * Writes {@code shards} with YDB BulkUpsert.
     *
     * <p>Behaviour depends on the list size:
     * <ul>
     *   <li>An empty list completes immediately.</li>
     *   <li>Up to {@code BATCH_MAX_SIZE} shards are sent as one request.</li>
     *   <li>Larger lists are cut into {@code BATCH_MAX_SIZE}-sized batches. An
     *       {@link AsyncActorRunner} on the common fork-join pool then drains them, limited by
     *       {@code ASYNC_IN_FLIGHT}.</li>
     * </ul>
     */
    @Override
    public CompletableFuture<Void> bulkUpsert(List<BalancerShard> shards) {
        return safeCall(() -> {
            if (shards.isEmpty()) {
                return completedFuture(null);
            }
            if (shards.size() > BATCH_MAX_SIZE) {
                var batches = Lists.partition(shards, BATCH_MAX_SIZE);
                var cursor = new AtomicInteger();
                // Each invocation claims the next unclaimed batch index. Once every index has
                // been claimed it returns DONE_MARKER (presumably telling the runner to stop).
                AsyncActorBody body = () -> {
                    int next = cursor.getAndIncrement();
                    if (next < batches.size()) {
                        return doBulkUpsert(batches.get(next));
                    }
                    return completedFuture(AsyncActorBody.DONE_MARKER);
                };
                return new AsyncActorRunner(body, ForkJoinPool.commonPool(), ASYNC_IN_FLIGHT).start();
            }
            return doBulkUpsert(shards);
        });
    }

    /**
     * Sends one BulkUpsert request for {@code batch} through the table's retry context. The
     * resulting status is checked with {@code status.expect("bulk upsert")}.
     */
    private CompletableFuture<Void> doBulkUpsert(List<BalancerShard> batch) {
        ListValue rows = table.shardsToListValue(batch);
        return table.retryCtx()
            .supplyStatus(session -> session.executeBulkUpsert(
                table.getPath(),
                rows,
                new BulkUpsertSettings().setTimeout(BULK_UPSERT_TIMEOUT)))
            .thenAccept(status -> status.expect("bulk upsert"));
    }

    /**
     * Typed accessor for the {@code BalancerShards} table.
     *
     * <p>It provides the schema, query parameter binding and result-row mapping for
     * {@link BalancerShard}, keyed by {@code id}.
     */
    @ParametersAreNonnullByDefault
    private static final class Table extends YdbTable<String, BalancerShard> {

        /**
         * Struct type of one bulk-upsert row.
         *
         * <p>NOTE(review): {@code shardsToListValue} fills it positionally as (createdAt, id). This
         * presumably relies on {@code StructType.of} ordering members by name. Confirm this
         * against the YDB SDK before adding columns.
         */
        private static final StructType BALANCER_SHARD_TYPE = StructType.of(
            Map.of(
                "id", PrimitiveType.utf8(),
                "createdAt", PrimitiveType.timestamp()));

        Table(TableClient tableClient, String path) {
            super(tableClient, path);
        }

        /**
         * Packs {@code shards} into a list of {@code BALANCER_SHARD_TYPE} structs for
         * BulkUpsert. Callers must pass a non-empty list; this is checked by an assertion.
         */
        ListValue shardsToListValue(List<BalancerShard> shards) {
            assert !shards.isEmpty();

            Value<?>[] rows = new Value<?>[shards.size()];
            int pos = 0;
            for (BalancerShard shard : shards) {
                // Unsafe positional construction: the argument order (createdAt, id) must not change!
                rows[pos++] = BALANCER_SHARD_TYPE.newValueUnsafe(
                    timestamp(millisToMicros(shard.createdAt())),
                    utf8(shard.id()));
            }
            return ListValue.of(rows);
        }

        /** Converts the epoch-millisecond {@code createdAt} value to the microseconds passed to {@code timestamp}. */
        private static long millisToMicros(long millis) {
            return TimeUnit.MILLISECONDS.toMicros(millis);
        }

        /** Schema: nullable {@code id} (utf8, primary key) and nullable {@code createdAt} (timestamp). */
        @Override
        protected TableDescription description() {
            return TableDescription.newBuilder()
                .addNullableColumn("id", PrimitiveType.utf8())
                .addNullableColumn("createdAt", PrimitiveType.timestamp())
                .setPrimaryKeys("id")
                .build();
        }

        @Override
        protected String getId(BalancerShard shard) {
            return shard.id();
        }

        /** Binds {@code $id} and {@code $createdAt} for the single-row queries. */
        @Override
        protected Params toParams(BalancerShard shard) {
            long createdAtMicros = millisToMicros(shard.createdAt());
            return Params.create()
                .put("$id", utf8(shard.id()))
                .put("$createdAt", timestamp(createdAtMicros));
        }

        /** Builds a shard from one result row, converting {@code createdAt} back to epoch millis. */
        @Override
        protected BalancerShard mapFull(ResultSetReader reader) {
            String id = reader.getColumn("id").getUtf8();
            long createdAtMillis = reader.getColumn("createdAt").getTimestamp().toEpochMilli();
            return new BalancerShard(id, createdAtMillis);
        }

        /** Rows carry no extra columns, so partial mapping is identical to full mapping. */
        @Override
        protected BalancerShard mapPartial(ResultSetReader reader) {
            return mapFull(reader);
        }
    }

}
