package ru.yandex.solomon.ydb;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.google.common.collect.Lists;
import com.yandex.ydb.table.SessionRetryContext;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.BulkUpsertSettings;
import com.yandex.ydb.table.settings.PartitioningSettings;
import com.yandex.ydb.table.values.ListValue;
import com.yandex.ydb.table.values.PrimitiveValue;
import com.yandex.ydb.table.values.StructType;
import com.yandex.ydb.table.values.TupleValue;
import com.yandex.ydb.table.values.Value;
import io.grpc.Status;
import io.grpc.Status.Code;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import ru.yandex.solomon.kikimr.LocalKikimr;
import ru.yandex.solomon.kikimr.YdbHelper;

import static com.yandex.ydb.table.values.PrimitiveType.utf8;
import static junit.framework.TestCase.assertEquals;
import static org.hamcrest.number.OrderingComparison.greaterThanOrEqualTo;
import static org.hamcrest.number.OrderingComparison.lessThanOrEqualTo;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;

/**
 * Integration tests for {@link YdbReadTableTask} against a {@link LocalKikimr} instance.
 *
 * <p>Every test creates its own table under {@code ydb.getRootPath()} with a random UUID suffix,
 * fills it through {@link #bulkUpsert} and reads it back with a task built by {@link #readTable}.
 *
 * @author Vladimir Gordiychuk
 */
public class YdbReadTableTaskTest {

    @ClassRule
    public static final LocalKikimr kikimr = new LocalKikimr();

    /** Row type of the single-key tables: utf8 {@code key} (primary key) and utf8 {@code value}. */
    private static final StructType KEY_VALUE_TYPE = StructType.of("key", utf8(), "value", utf8());

    /** Composite-key fixture sorted by (k1, k2): two "alice" rows, then one "bob" row. */
    private static final List<CompositeKeyTable.Record> TWO_ALICE_ONE_BOB = List.of(
            CompositeKeyTable.newRecord("alice", "001", "1"),
            CompositeKeyTable.newRecord("alice", "002", "2"),
            CompositeKeyTable.newRecord("bob", "001", "2"));

    /** Composite-key fixture sorted by (k1, k2): two "alice" rows, then three "bob" rows. */
    private static final List<CompositeKeyTable.Record> TWO_ALICE_THREE_BOB = List.of(
            CompositeKeyTable.newRecord("alice", "001", "1"),
            CompositeKeyTable.newRecord("alice", "002", "2"),
            CompositeKeyTable.newRecord("bob", "001", "1"),
            CompositeKeyTable.newRecord("bob", "002", "2"),
            CompositeKeyTable.newRecord("bob", "003", "3"));

    @Rule
    public TestName testName = new TestName();

    private YdbHelper ydb;
    private SessionRetryContext retryCtx;

    @Before
    public void setUp() throws Exception {
        ydb = new YdbHelper(kikimr, this.getClass().getSimpleName() + "_" + testName.getMethodName());
        retryCtx = SessionRetryContext.create(ydb.getTableClient())
                .maxRetries(10)
                .build();
    }

    @After
    public void tearDown() {
        // ydb stays null when setUp() failed before assigning it; don't pile an NPE on top of that failure.
        if (ydb != null) {
            ydb.close();
        }
    }

    @Test(expected = RuntimeException.class)
    public void notExistTable() {
        var task = readTable(nextTablePath())
                .consumer(new YdbResultSetConsumerStub(3))
                .build();

        // join() rethrows a failed read as CompletionException, which is a RuntimeException.
        task.run().join();
    }

    @Test
    public void nothingToRead() {
        var tablePath = createKeyValueTable();

        var consumer = new YdbResultSetConsumerStub(3);
        var task = readTable(tablePath)
                .consumer(consumer)
                .build();
        task.run().join();

        assertEquals(0, consumer.resultSets.size());
    }

    @Test
    public void limitedRead() {
        var tablePath = createKeyValueTable();
        bulkUpsert(tablePath, keyValueRows(10, "")).join();

        var consumer = new YdbResultSetConsumerStub(100);
        var task = readTable(tablePath)
                .consumer(consumer)
                .rowLimit(2)
                .build();
        task.run().join();

        // Without continueReadAfterLimit only rowLimit rows are expected in total.
        var rows = consumer.resultSets.stream().mapToInt(ResultSetReader::getRowCount).sum();
        assertEquals(2, rows);
    }

    @Test
    public void readAllLimitedBatch() {
        var tablePath = createKeyValueTable(10);
        var values = keyValueRows(100_000, "");
        bulkUpsert(tablePath, values).join();

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(tablePath)
                .consumer(consumer)
                .primaryKeys(List.of("key"))
                .continueReadAfterLimit(true)
                .rowLimit(10_000)
                .build();
        task.run().join();

        assertThat("split 100K by 10K", consumer.resultSets.size(), greaterThanOrEqualTo(10));
        assertEveryKeyReadOnce(values, collectKeys(consumer.resultSets, 10_000));
    }

    @Test
    public void readAllUseSelectLimitedBatch() {
        var tablePath = createKeyValueTable(10);
        var values = keyValueRows(100_000, "");
        bulkUpsert(tablePath, values).join();

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(tablePath)
                .consumer(consumer)
                .primaryKeys(List.of("key"))
                .continueReadAfterLimit(true)
                .rowLimit(10_000)
                .useSelect(true)
                .build();
        task.run().join();

        assertEveryKeyReadOnce(values, collectKeys(consumer.resultSets, 10_000));
    }

    @Test
    public void stopRetryOnConsumerError() {
        var tablePath = createKeyValueTable(10);
        bulkUpsert(tablePath, keyValueRows(100, "value")).join();

        // The consumer fails on its very first batch; that error must end the task after a single
        // attempt (no retry) and be surfaced with the consumer's status code.
        var consumer = new YdbResultSetConsumerStub(1000);
        AtomicInteger attempt = new AtomicInteger();
        consumer.before = () -> {
            attempt.incrementAndGet();
            throw Status.UNAVAILABLE.withDescription("hi").asRuntimeException();
        };

        var task = readTable(tablePath)
                .consumer(consumer)
                .build();

        var status = task.run()
                .thenApply(unused -> Status.OK)
                .exceptionally(Status::fromThrowable)
                .join();

        assertEquals(1, attempt.get());
        assertEquals(Code.UNAVAILABLE, status.getCode());
    }

    @Test
    public void pauseReadInTheMiddle() throws InterruptedException {
        var tablePath = createKeyValueTable(10);
        bulkUpsert(tablePath, keyValueRows(100, "value")).join();

        var consumer = new YdbResultSetConsumerStub(1000);
        CompletableFuture<Void> readyFuture = new CompletableFuture<>();
        // Once the `after` hook fires, the stub exposes a readiness future that is never completed
        // normally; presumably the task waits on it before fetching the next batch — TODO confirm.
        consumer.after = () -> consumer.readyFuture = readyFuture;

        var task = readTable(tablePath)
                .consumer(consumer)
                .rowLimit(10)
                .continueReadAfterLimit(true)
                .primaryKeys(List.of("key"))
                .build();

        var taskFuture = task.run();
        var resultSet = consumer.resultSets.take();
        assertThat("resultSet batch less than limit", resultSet.getRowCount(), lessThanOrEqualTo(10));
        // While the consumer is not ready, no further batch may arrive.
        assertNull(consumer.resultSets.poll(10, TimeUnit.MILLISECONDS));

        // Failing the readiness future must fail the whole task with the same status.
        readyFuture.completeExceptionally(Status.RESOURCE_EXHAUSTED.asRuntimeException());

        var status = taskFuture
                .thenApply(unused -> Status.OK)
                .exceptionally(Status::fromThrowable)
                .join();

        assertEquals(Code.RESOURCE_EXHAUSTED, status.getCode());
    }

    @Test
    public void compositeKeyReadAll() {
        var records = TWO_ALICE_ONE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .build();
        task.run().join();

        assertEquals(records, collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadFromKeyExclusive() {
        var records = TWO_ALICE_ONE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .fromKey(compositeKey("alice", "001"), false)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records.subList(1, records.size()), collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadFromKeyInclusive() {
        var records = TWO_ALICE_ONE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .fromKey(compositeKey("alice", "002"), true)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records.subList(1, records.size()), collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadToKeyExclusive() {
        var records = TWO_ALICE_THREE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .toKey(compositeKey("bob", "002"), false)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records.subList(0, 3), collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadToKeyInclusive() {
        var records = TWO_ALICE_THREE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .toKey(compositeKey("bob", "001"), true)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records.subList(0, 3), collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadFromKeyInclusiveToKeyExclusive() {
        var records = TWO_ALICE_THREE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .fromKey(compositeKey("alice", "002"), true)
                .toKey(compositeKey("bob", "001"), false)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records.subList(1, 2), collectRecords(consumer.resultSets));
    }

    @Test
    public void compositeKeyReadAllRowLimit() {
        var records = TWO_ALICE_ONE_BOB;
        var table = createCompositeKeyTable(records);

        var consumer = new YdbResultSetConsumerStub(1000);
        var task = readTable(table.path)
                .consumer(consumer)
                .primaryKeys(table.primaryKeys())
                .rowLimit(2)
                .continueReadAfterLimit(true)
                .build();
        task.run().join();

        assertEquals(records, collectRecords(consumer.resultSets));
    }

    @Test
    public void asyncConsumeResultSets() {
        var tablePath = createKeyValueTable(10);
        var values = keyValueRows(100_000, "");
        bulkUpsert(tablePath, values).join();

        // NOTE(review): the callback runs on ForkJoinPool.commonPool(); a plain ArrayList is only safe
        // if YdbResultSetAsyncConsumer invokes it sequentially — confirm.
        var readKeys = new ArrayList<String>();
        var consumer = new YdbResultSetAsyncConsumer(
                rs -> appendKeys(rs, 10_000, readKeys),
                20_000,
                ForkJoinPool.commonPool());

        var task = readTable(tablePath)
                .consumer(consumer)
                .primaryKeys(List.of("key"))
                .continueReadAfterLimit(true)
                .rowLimit(10_000)
                .build();
        task.run().join();
        consumer.done().join();

        assertEveryKeyReadOnce(values, readKeys);
    }

    /** Returns a fresh, unique table path under the test's YDB root. */
    private String nextTablePath() {
        return ydb.getRootPath() + "/table_" + UUID.randomUUID();
    }

    /**
     * Creates a key/value table (layout of {@link #KEY_VALUE_TYPE}) without explicit partitioning
     * settings and waits for the creation to finish.
     *
     * @return path of the created table
     */
    private String createKeyValueTable() {
        return createKeyValueTable(0);
    }

    /**
     * Creates a key/value table (layout of {@link #KEY_VALUE_TYPE}) under a fresh path and waits for
     * the creation to finish.
     *
     * @param minPartitions minimum partition count; a non-positive value leaves the partitioning
     *                      settings unset
     * @return path of the created table
     */
    private String createKeyValueTable(int minPartitions) {
        var builder = TableDescription.newBuilder()
                .addNullableColumn("key", utf8())
                .addNullableColumn("value", utf8())
                .setPrimaryKey("key");
        if (minPartitions > 0) {
            builder = builder.setPartitioningSettings(new PartitioningSettings().setMinPartitionsCount(minPartitions));
        }

        var tablePath = nextTablePath();
        createTable(tablePath, builder.build()).join();
        return tablePath;
    }

    /**
     * Builds {@code count} key/value rows whose keys are the decimal strings {@code "0"} ..
     * {@code String.valueOf(count - 1)} and whose values are {@code key + valueSuffix}.
     */
    private static List<Value<?>> keyValueRows(int count, String valueSuffix) {
        return IntStream.range(0, count)
                .mapToObj(String::valueOf)
                .map(key -> KEY_VALUE_TYPE.newValue(
                        "key", PrimitiveValue.utf8(key),
                        "value", PrimitiveValue.utf8(key + valueSuffix)))
                .collect(Collectors.toList());
    }

    /**
     * Asserts that {@code rs} holds at most {@code maxRows} rows, then appends the {@code key} column
     * of every remaining row to {@code sink}.
     */
    private static void appendKeys(ResultSetReader rs, int maxRows, List<String> sink) {
        assertThat("resultSet batch less than limit", rs.getRowCount(), lessThanOrEqualTo(maxRows));
        int keyIdx = rs.getColumnIndex("key");
        while (rs.next()) {
            sink.add(YdbResultSets.utf8(rs, keyIdx));
        }
    }

    /** Collects the {@code key} column of all result sets, in order, checking each batch size. */
    private static List<String> collectKeys(Iterable<? extends ResultSetReader> resultSets, int maxRowsPerBatch) {
        var keys = new ArrayList<String>();
        for (var rs : resultSets) {
            appendKeys(rs, maxRowsPerBatch, keys);
        }
        return keys;
    }

    /** Asserts that exactly {@code expectedRows.size()} keys were read and that none repeats. */
    private static void assertEveryKeyReadOnce(List<?> expectedRows, List<String> readKeys) {
        assertEquals(expectedRows.size(), readKeys.size());
        assertEquals(expectedRows.size(), new HashSet<>(readKeys).size());
    }

    /** Creates a {@link CompositeKeyTable} under a fresh path and upserts {@code records} into it. */
    private CompositeKeyTable createCompositeKeyTable(List<CompositeKeyTable.Record> records) {
        var table = new CompositeKeyTable(nextTablePath());
        createTable(table.path, table.description()).join();
        bulkUpsert(table.path, CompositeKeyTable.toValues(records)).join();
        return table;
    }

    /** Builds a (k1, k2) key tuple; components are optional to match the nullable key columns. */
    private static TupleValue compositeKey(String k1, String k2) {
        return TupleValue.of(PrimitiveValue.utf8(k1).makeOptional(), PrimitiveValue.utf8(k2).makeOptional());
    }

    /** Decodes every row of every result set, preserving delivery order. */
    private static List<CompositeKeyTable.Record> collectRecords(Iterable<? extends ResultSetReader> resultSets) {
        var records = new ArrayList<CompositeKeyTable.Record>();
        for (var rs : resultSets) {
            records.addAll(CompositeKeyTable.records(rs));
        }
        return records;
    }

    /**
     * Upserts {@code values} in chunks of 10 000 rows, strictly one chunk after another, each chunk
     * through {@link #retryCtx}.
     *
     * <p>The chain is assembled on a not-yet-completed root future and only started once fully built.
     * The returned future completes when every chunk has been written; after a failed chunk the
     * remaining chunks are skipped and the future completes exceptionally.
     */
    private CompletableFuture<Void> bulkUpsert(String tablePath, List<? extends Value<?>> values) {
        var root = new CompletableFuture<Void>();
        var future = root;
        for (var part : Lists.partition(values, 10_000)) {
            var listValues = ListValue.of(part.toArray(Value[]::new));
            future = future.thenCompose(unused -> retryCtx.supplyStatus(session -> {
                return session.executeBulkUpsert(tablePath, listValues, new BulkUpsertSettings());
            }).thenAccept(status -> status.expect("can not bulk upsert")));
        }
        root.complete(null);
        return future;
    }

    /** Creates {@code tablePath} from {@code description} through {@link #retryCtx}. */
    private CompletableFuture<Void> createTable(String tablePath, TableDescription description) {
        return retryCtx.supplyStatus(session -> session.createTable(tablePath, description))
                .thenAccept(status -> status.expect("can not create table"));
    }

    /** Starts a read-table task builder for {@code path} that uses this test's retry context. */
    public YdbReadTableTask.Builder readTable(String path) {
        return YdbReadTableTask.newBuilder()
                .tablePath(path)
                .retryContext(retryCtx);
    }

    /**
     * Table with a two-column primary key ({@code k1}, {@code k2}) plus a {@code value} column, all
     * nullable utf8, configured with a minimum of 10 partitions.
     */
    private static class CompositeKeyTable {
        private static final StructType RECORD_TYPE = StructType.of("k1", utf8(), "k2", utf8(), "value", utf8());

        private final String path;

        public CompositeKeyTable(String path) {
            this.path = path;
        }

        public TableDescription description() {
            return TableDescription.newBuilder()
                    .addNullableColumn("k1", utf8())
                    .addNullableColumn("k2", utf8())
                    .addNullableColumn("value", utf8())
                    .setPrimaryKeys("k1", "k2")
                    .setPartitioningSettings(new PartitioningSettings().setMinPartitionsCount(10))
                    .build();
        }

        public List<String> primaryKeys() {
            return description().getPrimaryKeys();
        }

        public static Record newRecord(String k1, String k2, String value) {
            return new Record(k1, k2, value);
        }

        /** Reads all remaining rows of {@code rs} into records, in result-set order. */
        public static List<Record> records(ResultSetReader rs) {
            var result = new ArrayList<Record>(rs.getRowCount());
            int k1Idx = rs.getColumnIndex("k1");
            int k2Idx = rs.getColumnIndex("k2");
            int valueIdx = rs.getColumnIndex("value");
            while (rs.next()) {
                var k1 = YdbResultSets.utf8(rs, k1Idx);
                var k2 = YdbResultSets.utf8(rs, k2Idx);
                var value = YdbResultSets.utf8(rs, valueIdx);

                var record = newRecord(k1, k2, value);
                result.add(record);
            }
            return result;
        }

        /** Converts records into struct values of {@link #RECORD_TYPE}, ready for bulk upsert. */
        public static List<Value<?>> toValues(List<Record> records) {
            return records.stream()
                    .map(record -> RECORD_TYPE.newValue(
                            "k1", PrimitiveValue.utf8(record.k1),
                            "k2", PrimitiveValue.utf8(record.k2),
                            "value", PrimitiveValue.utf8(record.value)
                    ))
                    .collect(Collectors.toList());
        }

        record Record(String k1, String k2, String value){}
    }
}
