package ru.yandex.solomon.coremon.balancer.db;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.DataQueryResult;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.settings.ExecuteDataQuerySettings;
import com.yandex.ydb.table.settings.ReadTableSettings;
import com.yandex.ydb.table.transaction.TxControl;
import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.StructType;
import com.yandex.ydb.table.values.Value;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntSet;

import static java.util.concurrent.CompletableFuture.failedFuture;

/**
 * YDB-backed {@link ShardAssignmentsDao} that keeps the shard-to-host mapping in the
 * {@code ShardAssignments} table located under the given root path.
 *
 * <p>Rows are keyed by {@code shardId} (Uint32) and hold a {@code host} (Utf8) value; see
 * {@link #createSchemaForTests()} for the layout this class expects. All operations are
 * asynchronous and go through a {@link SessionRetryContext} configured with at most 10 retries
 * and {@link ForkJoinPool#commonPool()} as its executor.
 *
 * @author Sergey Polovko
 */
public class YdbShardAssignmentsDao implements ShardAssignmentsDao {

    /** Maximum number of rows passed in a single bulk REPLACE/DELETE query. */
    static final int BATCH_SIZE = 1000;

    private static final StructType ASSIGNMENT_TYPE = StructType.of(
        "shardId", PrimitiveType.uint32(), "host", PrimitiveType.utf8());
    private static final ListType ASSIGNMENTS_LIST_TYPE = ListType.of(ASSIGNMENT_TYPE);

    private static final StructType ID_TYPE = StructType.of("shardId", PrimitiveType.uint32());
    private static final ListType ID_LIST_TYPE = ListType.of(ID_TYPE);

    private final String tablePath;
    private final SessionRetryContext sessionCtx;
    /** Bulk upsert of {@code $rows} (list of {shardId, host}). */
    private final String saveQuery;
    /** Bulk delete of the rows whose keys are listed in {@code $rows} (list of {shardId}). */
    private final String deleteQuery;
    /** Changes the host of a single, already existing shard row. */
    private final String changeQuery;

    /**
     * @param rootPath    YDB directory that contains the {@code ShardAssignments} table
     * @param tableClient client used to obtain sessions for all queries
     */
    public YdbShardAssignmentsDao(String rootPath, TableClient tableClient) {
        this.tablePath = rootPath + "/ShardAssignments";
        this.sessionCtx = SessionRetryContext.create(tableClient)
            .maxRetries(10)
            .executor(ForkJoinPool.commonPool())
            .build();

        this.saveQuery = String.format("""
                --!syntax_v1
                DECLARE $rows AS %s;
                REPLACE INTO `%s` SELECT * FROM AS_TABLE($rows);
                """, ASSIGNMENTS_LIST_TYPE, tablePath);

        this.deleteQuery = String.format("""
                --!syntax_v1
                DECLARE $rows AS %s;
                DELETE FROM `%s` ON SELECT * FROM AS_TABLE($rows);
                """, ID_LIST_TYPE, tablePath);

        this.changeQuery = String.format("""
                --!syntax_v1
                DECLARE $shardId AS Uint32;
                DECLARE $host AS Utf8;
                UPDATE `%s` SET host = $host WHERE shardId = $shardId;
                """, tablePath);
    }

    /**
     * Reads the whole assignments table with an unordered read-table request (1 minute timeout).
     *
     * @return future completed with every stored assignment; it fails if the read does not
     *         succeed within the retry budget
     */
    @Override
    public CompletableFuture<ShardAssignments> load() {
        ReadTableSettings settings = ReadTableSettings.newBuilder()
            .timeout(1, TimeUnit.MINUTES)
            .orderedRead(false)
            .build();

        return sessionCtx.supplyResult(session -> {
            // Allocated per attempt, so a retried read never sees rows from a failed attempt.
            Int2ObjectMap<String> result = new Int2ObjectOpenHashMap<>(1000);
            return session.readTable(tablePath, settings, resultSet -> {
                final int shardIdIdx = resultSet.getColumnIndex("shardId");
                final int hostIdx = resultSet.getColumnIndex("host");
                while (resultSet.next()) {
                    // Uint32 is read as a long; the narrowing cast keeps the same 32 bits
                    // that save() wrote from an int.
                    result.put(
                        (int) resultSet.getColumn(shardIdIdx).getUint32(),
                        resultSet.getColumn(hostIdx).getUtf8());
                }
            }).thenApply(s -> s.isSuccess() ? Result.success(ShardAssignments.ownOf(result)) : Result.fail(s));
        }).thenApply(result -> result.expect("cannot load assignments"));
    }

    /**
     * Upserts (REPLACE) every given assignment.
     *
     * <p>Rows are written in batches of at most {@link #BATCH_SIZE}, one batch at a time. An empty
     * input completes immediately without issuing a query.
     *
     * @param shard2Host assignments to store
     * @return future completed once all batches are written; it fails on the first failed batch
     */
    @Override
    public CompletableFuture<Void> save(ShardAssignments shard2Host) {
        List<Value> values = new ArrayList<>(shard2Host.size());
        for (Int2ObjectMap.Entry<String> e : shard2Host.asMap().int2ObjectEntrySet()) {
            int shardId = e.getIntKey();
            String host = e.getValue();
            values.add(ASSIGNMENT_TYPE.newValueUnsafe(
                // NOTE: unsafe construction, do not modify order of parameters!
                // newValueUnsafe skips validation and takes values in the struct's internal
                // member order (host, shardId), not the declaration order of ASSIGNMENT_TYPE.
                PrimitiveValue.utf8(host),
                PrimitiveValue.uint32(shardId)
            ));
        }
        return executeInBatches(saveQuery, ASSIGNMENTS_LIST_TYPE, values, "cannot save assignments");
    }

    /**
     * Deletes the rows of the given shards.
     *
     * <p>Keys are sent in batches of at most {@link #BATCH_SIZE}, one batch at a time. An empty
     * set completes immediately without issuing a query.
     *
     * @param shardIds ids of the shards whose assignments must be removed
     * @return future completed once all batches are deleted; it fails on the first failed batch
     */
    @Override
    public CompletableFuture<Void> delete(IntSet shardIds) {
        List<Value> values = new ArrayList<>(shardIds.size());
        for (var it = shardIds.iterator(); it.hasNext(); ) {
            int shardId = it.nextInt();
            values.add(ID_TYPE.newValueUnsafe(PrimitiveValue.uint32(shardId)));
        }
        return executeInBatches(deleteQuery, ID_LIST_TYPE, values, "cannot delete assignments");
    }

    /**
     * Changes the host of a single shard.
     *
     * <p>This is a YQL {@code UPDATE}, so only an existing row is modified. If there is no row for
     * {@code shardId}, nothing is inserted and the future still completes successfully.
     *
     * @param shardId shard whose assignment is changed
     * @param host    new host for the shard
     * @return future completed when the query succeeds, or failed otherwise
     */
    @Override
    public CompletableFuture<Void> update(int shardId, String host) {
        Params params = Params.of("$shardId", PrimitiveValue.uint32(shardId), "$host", PrimitiveValue.utf8(host));
        return execute(changeQuery, params)
            .thenAccept(result -> {
                result.expect("cannot change assignment (" + shardId + ", " + host + ')');
            });
    }

    /** Creates the assignments table (shardId Uint32 primary key, host Utf8); intended for tests only. */
    @VisibleForTesting
    public CompletableFuture<Void> createSchemaForTests() {
        var tableDesc = TableDescription.newBuilder()
            .addNullableColumn("shardId", PrimitiveType.uint32())
            .addNullableColumn("host", PrimitiveType.utf8())
            .setPrimaryKey("shardId")
            .build();

        return sessionCtx.supplyStatus(s -> s.createTable(tablePath, tableDesc))
            .thenAccept(status -> status.expect("cannot create table " + tablePath));
    }

    /** Drops the assignments table; intended for tests only. */
    @VisibleForTesting
    CompletableFuture<Void> dropSchemaForTests() {
        return sessionCtx.supplyStatus(s -> s.dropTable(tablePath))
            .thenAccept(status -> status.expect("cannot drop table " + tablePath));
    }

    /**
     * Runs a {@code $rows}-parameterized bulk query over {@code rows}.
     *
     * <p>The rows are split into chunks of at most {@link #BATCH_SIZE} elements. The chunks are
     * executed strictly one after another (inFlight=1), each as a separate query. If a chunk
     * fails, the remaining chunks are skipped, but chunks that already succeeded are not undone.
     *
     * @param query        YQL text declaring a {@code $rows} parameter of type {@code rowsType}
     * @param rowsType     list type used to build the {@code $rows} value
     * @param rows         row values; an empty list completes immediately without a query
     * @param errorMessage message used when a chunk's result is not successful
     */
    private CompletableFuture<Void> executeInBatches(
        String query, ListType rowsType, List<Value> rows, String errorMessage)
    {
        if (rows.isEmpty()) {
            // Nothing to write: skip the round-trip and the serializable transaction.
            return CompletableFuture.completedFuture(null);
        }
        if (rows.size() <= BATCH_SIZE) {
            return executeBatch(query, rowsType, rows, errorMessage);
        }

        CompletableFuture<Void> future = CompletableFuture.completedFuture(null);
        for (List<Value> batch : Lists.partition(rows, BATCH_SIZE)) {
            future = future.thenCompose(aVoid -> executeBatch(query, rowsType, batch, errorMessage));
        }
        return future;
    }

    /** Executes {@code query} once with {@code $rows} bound to {@code rows}; fails if the result is not successful. */
    private CompletableFuture<Void> executeBatch(
        String query, ListType rowsType, List<Value> rows, String errorMessage)
    {
        Params params = Params.of("$rows", rowsType.newValue(rows));
        return execute(query, params)
            .thenAccept(result -> result.expect(errorMessage));
    }

    /**
     * Executes a data query in a serializable read-write transaction, with retries, keeping the
     * query in the server-side query cache.
     *
     * <p>Anything thrown synchronously while scheduling the query is returned as a failed future
     * instead of propagating to the caller.
     */
    private CompletableFuture<Result<DataQueryResult>> execute(String query, Params params) {
        try {
            return sessionCtx.supplyResult(s -> {
                var settings = new ExecuteDataQuerySettings().keepInQueryCache();
                var tx = TxControl.serializableRw();
                return s.executeDataQuery(query, tx, params, settings);
            });
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }
}
