package ru.yandex.solomon.alert.dao.ydb.entity;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.description.TableDescription;
import com.yandex.ydb.table.query.Params;
import com.yandex.ydb.table.result.ResultSetReader;
import com.yandex.ydb.table.settings.ReadTableSettings;
import com.yandex.ydb.table.values.PrimitiveType;
import com.yandex.ydb.table.values.TupleValue;
import io.grpc.Status;

import ru.yandex.solomon.codec.serializer.ByteStringsStockpile;
import ru.yandex.solomon.core.db.dao.kikimr.QueryTemplate;
import ru.yandex.solomon.core.db.dao.kikimr.QueryText;
import ru.yandex.solomon.ydb.YdbTable;

import static com.yandex.ydb.table.values.PrimitiveValue.string;
import static com.yandex.ydb.table.values.PrimitiveValue.uint32;
import static com.yandex.ydb.table.values.PrimitiveValue.utf8;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.failedFuture;

/**
 * @author Vladimir Gordiychuk
 */
public class YdbAlertStateChunksDao implements AlertStatesChunksDao {
    private static final int CHUNK_SIZE = 4 << 20; // 4 MiB

    private static final QueryTemplate TEMPLATE = new QueryTemplate(
        YdbTelegramDao.class,
        "alert_state_chunks",
        List.of(
            "delete_all",
            "upload",
            "delete_file",
            "select_one"));

    private final String tablePath;
    private final ChunksTable table;
    private final QueryText queryText;

    public YdbAlertStateChunksDao(String path, TableClient tableClient) {
        this.tablePath = path + "/StatesChunks";
        this.table = new ChunksTable(tableClient, tablePath);
        this.queryText = TEMPLATE.build(Collections.singletonMap("alert_state_chunks.table.path", tablePath));
    }

    @Override
    public CompletableFuture<?> createSchemaForTests() {
        return table.create();
    }

    public CompletableFuture<?> deleteProject(String projectId) {
        try {
            String query = queryText.query("delete_all");
            Params params = Params.of("$projectId", utf8(projectId));
            return table.queryVoid(query, params);
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }

    @Override
    public CompletableFuture<Integer> uploadChunks(String projectId, String fileId, ByteString bytes) {
        var split = ByteStringsStockpile.split(bytes, CHUNK_SIZE);

        CompletableFuture<?> future = null;
        for (int index = 0; index < split.length; index++) {
            var chunk = new Chunk();
            chunk.projectId = projectId;
            chunk.fileId = fileId;
            chunk.num = index;
            chunk.bytes = split[index];
            if (future == null) {
                future = uploadChunk(chunk);
            } else {
                future = future.thenCompose(ignore -> uploadChunk(chunk));
            }
        }
        int size = split.length;
        return requireNonNull(future).thenApply(ignore -> size);
    }

    private CompletableFuture<?> uploadChunk(Chunk record) {
        try {
            var query = queryText.query("upload");
            var params = table.toParams(record);
            return table.queryVoid(query, params);
        } catch (Throwable e) {
            return failedFuture(e);
        }
    }

    @Override
    public CompletableFuture<ByteString> downloadChunks(String projectId, String fileId, int countChunks) {
        if (countChunks == 0) {
            return CompletableFuture.completedFuture(ByteString.EMPTY);
        } else if (countChunks == 1) {
            return downloadChunksAsSelect(projectId, fileId);
        } else {
            return downloadChunksAsReadtable(projectId, fileId, countChunks);
        }
    }

    private CompletableFuture<ByteString> downloadChunksAsSelect(String projectId, String fileId) {
        try {
            var query = queryText.query("select_one");
            Params params = Params.of("$projectId", utf8(projectId), "$fileId", utf8(fileId));
            return table.queryOne(query, params)
                    .thenApply(chunk -> {
                        if (chunk.isEmpty()) {
                            throw Status.DATA_LOSS
                                    .withDescription("Loaded chunks 0 for file " + fileId + " at project "+ projectId + " but expected 1")
                                    .asRuntimeException();
                        }

                        return chunk.get().bytes;
                    });
        } catch (Throwable e) {
            return failedFuture(e);
        }
    }

    private CompletableFuture<ByteString> downloadChunksAsReadtable(String projectId, String fileId, int countChunks) {
        var key = TupleValue.of(utf8(projectId).makeOptional(), utf8(fileId).makeOptional());
        ReadTableSettings settings = ReadTableSettings.newBuilder()
                .orderedRead(true)
                .timeout(1, TimeUnit.MINUTES)
                .fromKeyInclusive(key)
                .toKeyInclusive(key)
                .build();

        return table.queryAll(ignore -> true, settings)
                .thenApply(chunks -> {
                    List<Chunk> sorted = new ArrayList<>(chunks);
                    sorted.sort(Comparator.comparingInt(o -> o.num));
                    ensureChunksConsistent(projectId, fileId, sorted, countChunks);
                    return sorted.stream()
                            .map(chunk -> chunk.bytes)
                            .reduce(ByteString.EMPTY, ByteString::concat);
                });
    }

    private void ensureChunksConsistent(String projectId, String fileName, List<Chunk> chunks, int expectedCount) {
        for (int index = 0; index < chunks.size(); index++) {
            var chunk = chunks.get(index);
            if (chunk.num != index) {
                throw Status.DATA_LOSS
                    .withDescription("Not found chunk with index " + index + " for file " + fileName + " at project "+ projectId)
                    .asRuntimeException();
            }
        }

        if (chunks.size() != expectedCount) {
            throw Status.DATA_LOSS
                .withDescription("Loaded chunks " + chunks.size() + " for file " + fileName + " at project "+ projectId + " but expected " + expectedCount)
                .asRuntimeException();
        }
    }

    public CompletableFuture<?> deleteFileChunks(String projectId, String fileId) {
        try {
            String query = queryText.query("delete_file");
            Params params = Params.of("$projectId", utf8(projectId), "$fileId", utf8(fileId));
            return table.queryVoid(query, params);
        } catch (Throwable t) {
            return failedFuture(t);
        }
    }

    public int getChunkSize() {
        return CHUNK_SIZE;
    }

    /**
     * CHUNKS TABLE
     */
    private static final class ChunksTable extends YdbTable<String, Chunk> {
        ChunksTable(TableClient tableClient, String path) {
            super(tableClient, path);
        }

        @Override
        protected TableDescription description() {
            return TableDescription.newBuilder()
                .addNullableColumn("projectId", PrimitiveType.utf8())
                .addNullableColumn("fileId", PrimitiveType.utf8())
                .addNullableColumn("num", PrimitiveType.uint32())
                .addNullableColumn("bytes", PrimitiveType.string())
                .setPrimaryKeys("projectId", "fileId", "num")
                .build();
        }

        @Override
        protected String getId(Chunk record) {
            return record.fileId + "_" + record.num;
        }

        @Override
        protected Params toParams(Chunk record) {
            return Params.create()
                .put("$projectId", utf8(record.projectId))
                .put("$fileId", utf8(record.fileId))
                .put("$num", uint32(record.num))
                .put("$bytes", string(record.bytes));
        }

        @Override
        protected Chunk mapFull(ResultSetReader r) {
            var chunk = new Chunk();
            chunk.projectId = r.getColumn("projectId").getUtf8();
            chunk.fileId = r.getColumn("fileId").getUtf8();
            chunk.num = (int) r.getColumn("num").getUint32();
            chunk.bytes = UnsafeByteOperations.unsafeWrap(r.getColumn("bytes").getString());
            return chunk;
        }

        @Override
        protected Chunk mapPartial(ResultSetReader r) {
            return mapFull(r);
        }
    }

    private static class Chunk {
        String projectId;
        String fileId;
        int num;
        ByteString bytes;
    }
}
