package ru.yandex.solomon.ydb;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.values.ListType;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.StructType;
import com.yandex.ydb.table.values.Value;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.solomon.kikimr.LocalKikimr;
import ru.yandex.solomon.kikimr.YdbHelper;
import ru.yandex.solomon.ydb.page.TokenPageOptions;

import static com.yandex.ydb.table.values.PrimitiveValue.int32;
import static com.yandex.ydb.table.values.PrimitiveValue.utf8;
import static java.util.stream.Collectors.toList;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
import static org.junit.Assert.assertEquals;
import static ru.yandex.misc.concurrent.CompletableFutures.join;

/**
 * @author Stanislav Kashirin
 */
public class YdbTableTest {

    private static final StructType RECORD_TYPE = StructType.of(Map.of(
        "id", PrimitiveType.utf8(),
        "payload", PrimitiveType.utf8()
    ));

    private static final ListType RECORD_LIST_TYPE = ListType.of(RECORD_TYPE);

    //language=SQL
    private static final String INSERT_MANY = """
            --!syntax_v1
            DECLARE $rows AS LIST_TYPE;
            REPLACE INTO `TABLE_PATH` SELECT * FROM AS_TABLE($rows)
            """;

    //language=SQL
    private static final String LIST_PAGED = """
            --!syntax_v1
            DECLARE $lastId AS Utf8;
            DECLARE $pageSize AS Int32;

            SELECT * FROM `TABLE_PATH`
            WHERE id > $lastId
            LIMIT $pageSize
            """;

    //language=SQL
    private static final String LIST_PAGED_OFFSET = """
            --!syntax_v1
            SELECT COUNT(*) FROM `TABLE_PATH`;

            SELECT * FROM `TABLE_PATH`
            LIMIT ${page.size} OFFSET ${page.offset};
            """;

    @ClassRule
    public static final LocalKikimr kikimr = new LocalKikimr();

    @Rule
    public TestName testName = new TestName();

    private YdbHelper ydb;
    private TestTable table;

    private String insertManyQuery;
    private String listPagedQuery;
    private String listPagedOffsetQuery;

    @Before
    public void setUp() throws Exception {
        ydb = new YdbHelper(kikimr, this.getClass().getSimpleName() + "_" + testName.getMethodName());
        var tablePath = ydb.getRootPath() + "/TestTable";

        table = new TestTable(ydb.getTableClient(), tablePath);
        join(table.create());

        insertManyQuery = INSERT_MANY
            .replaceAll("LIST_TYPE", RECORD_LIST_TYPE.toString())
            .replaceAll("TABLE_PATH", tablePath);
        listPagedQuery = LIST_PAGED.replaceAll("TABLE_PATH", tablePath);
        listPagedOffsetQuery = LIST_PAGED_OFFSET.replaceAll("TABLE_PATH", tablePath);
    }

    @After
    public void tearDown() {
        join(table.drop());
        ydb.close();
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    @Test
    public void pagingWhenYdbTruncatesResultSet() {
        // arrange
        var records = Stream.generate(() -> new TestRecord(randomAlphanumeric(8), randomAlphanumeric(12)))
            .limit(1111)
            .collect(toList());

        var rows = records.stream()
            .map(table::toParams)
            .<Value>map(params -> RECORD_TYPE.newValue((Map) params.values()))
            .collect(toList());

        join(table.execute(insertManyQuery, Params.of("$rows", RECORD_LIST_TYPE.newValue(rows))));

        // act
        var token = "";
        var pages = new ArrayList<List<TestRecord>>();
        do {
            var pageOpts = new TokenPageOptions(1000, token);
            var params = Params.of("$lastId", utf8(token), "$pageSize", int32(1001));
            var page = join(table.queryPage(listPagedQuery, params, pageOpts, TestRecord::id));
            pages.add(page.getItems());
            token = page.getNextPageToken();
        } while (!token.isEmpty());

        // assert
        var expected = records.stream()
            .sorted(Comparator.comparing(TestRecord::id))
            .collect(toList());

        var actual = pages.stream()
            .flatMap(Collection::stream)
            .collect(toList());

        assertEquals("rows selected", expected.size(), actual.size());
        assertEquals(expected, actual);
    }

    @Test
    public void pageSizeAllTimeout() {
        for (int i = 0; i < 100; i++) {
            var records = Stream
                    .generate(() -> new TestRecord(randomAlphanumeric(8), randomAlphanumeric(12)))
                    .limit(1111)
                    .collect(toList());

            var rows = records.stream()
                    .map(table::toParams)
                    .<Value>map(params -> RECORD_TYPE.newValue((Map) params.values()))
                    .collect(toList());

            join(table.execute(insertManyQuery, Params.of("$rows", RECORD_LIST_TYPE.newValue(rows))));
        }

        try {
            var finder = new AsyncFinder<>(table, Params.empty(), pageOpts ->
                    listPagedOffsetQuery
                            .replace("${page.offset}", Integer.toString(pageOpts.getOffset()))
                            .replace("${page.size}", Integer.toString(pageOpts.getSize())));
            finder.findAll(Instant.now().plusSeconds(1));
            finder.getFuture().join();
        } catch (Throwable e) {
            Throwable cause = CompletableFutures.unwrapCompletionException(e);
            if (cause instanceof StatusRuntimeException sre) {
                Assert.assertEquals(Status.Code.DEADLINE_EXCEEDED, sre.getStatus().getCode());
                return;
            }
        }
        Assert.fail("timeout did not occur");
    }

    static class TestTable extends YdbTable<String, TestRecord> {
        protected TestTable(TableClient tableClient, String path) {
            super(tableClient, path);
        }

        @Override
        protected TableDescription description() {
            return TableDescription.newBuilder()
                .addNullableColumn("id", PrimitiveType.utf8())
                .addNullableColumn("payload", PrimitiveType.utf8())
                .setPrimaryKey("id")
                .build();
        }

        @Override
        protected String getId(TestRecord testRecord) {
            return testRecord.id();
        }

        @Override
        protected Params toParams(TestRecord testRecord) {
            return Params.of(
                "id", utf8(testRecord.id()),
                "payload", utf8(testRecord.payload()));
        }

        @Override
        protected TestRecord mapFull(ResultSetReader resultSet) {
            return new TestRecord(
                resultSet.getColumn("id").getUtf8(),
                resultSet.getColumn("payload").getUtf8()
            );
        }

        @Override
        protected TestRecord mapPartial(ResultSetReader resultSet) {
            return mapFull(resultSet);
        }
    }

    static record TestRecord(String id, String payload) {
    }
}
