package ru.yandex.chemodan.app.djfs.migrator.migrations;

import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.jdbc.core.ColumnMapRowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgCursorUtils;
import ru.yandex.chemodan.app.djfs.core.user.DjfsUid;
import ru.yandex.chemodan.app.djfs.core.util.StreamUtils;
import ru.yandex.chemodan.app.djfs.migrator.DjfsCopyConfiguration;
import ru.yandex.chemodan.app.djfs.migrator.PgSchema;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;

/**
 * Table migration for one user's album data.
 * <p>
 * Copies the user's rows of {@code disk.albums} and {@code disk.album_items} from the source
 * shard to the destination shard. It can also verify that the copy matches the source and
 * delete the user's album data from a shard.
 */
public class DjfsAlbumsMigration implements DjfsTableMigration {

    /**
     * Copies the user's albums and album items from the source shard to the destination shard.
     * <p>
     * Albums are inserted with {@code cover_id} set to {@code NULL}. The original non-null cover
     * ids are remembered and written back only after all album items have been copied.
     * NOTE(review): this ordering presumably exists because {@code cover_id} references an
     * {@code album_items} row that is not yet present on the destination — confirm against the
     * schema.
     *
     * @param migrationConf  provides the source/destination shards, the uid and the batch size
     * @param databaseSchema schema passed to {@code DjfsMigrationUtil.copyRows} for the inserts
     * @param callback       run after every copied batch of albums or album items
     */
    @Override
    public void runCopying(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema,
            Runnable callback)
    {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        // (album id, original cover id) pairs to write back once album_items are in place.
        ListF<Tuple2<byte[], byte[]>> coversToRestore = Cf.arrayList();

        PgCursorUtils.queryWithCursorAsBatches(
                srcShard.getDataSource(), migrationConf.getBaseBatchSize(),
                "SELECT * FROM disk.albums WHERE uid = ?",
                new ColumnMapRowMapper(),
                migrationConf.getUid().asLong()
        ).forEach(fetchedRows -> {
            // Mutate the fetched rows explicitly instead of via Stream.peek side effects.
            fetchedRows.forEach(row -> detachCover(row, coversToRestore));
            DjfsMigrationUtil.copyRows(dstShard, databaseSchema, "albums", fetchedRows);
            callback.run();
        });

        PgCursorUtils.queryWithCursorAsBatches(
                srcShard.getDataSource(), migrationConf.getBaseBatchSize(),
                "SELECT * FROM disk.album_items WHERE uid = ?",
                new ColumnMapRowMapper(),
                migrationConf.getUid().asLong()
        ).forEach(fetchedRows -> {
            DjfsMigrationUtil.copyRows(dstShard, databaseSchema, "album_items", fetchedRows);
            callback.run();
        });

        // All album items now exist on the destination, so the real covers can be restored.
        StreamUtils.batches(coversToRestore.stream(), migrationConf.getBaseBatchSize())
                .forEach(batch -> DjfsMigrationUtil.withDisabledMigrationLockCheck(
                        dstShard,
                        () -> dstShard.batchUpdate(
                                "UPDATE disk.albums SET cover_id = ? WHERE albums.id = ?",
                                batch.map(pair -> new Object[]{pair.get2(), pair.get1()})
                        )
                ));
    }

    /**
     * Clears {@code cover_id} in a fetched album row and records the original value for a later
     * restore.
     * <p>
     * Rows whose cover is already {@code NULL} are not recorded. They are inserted with a
     * {@code NULL} cover anyway, so restoring them would be a no-op update.
     *
     * @param row             album row as produced by {@link ColumnMapRowMapper}; mutated in place
     * @param coversToRestore receives {@code (id, cover_id)} for rows that had a cover
     */
    private static void detachCover(Map<String, Object> row,
            ListF<Tuple2<byte[], byte[]>> coversToRestore)
    {
        byte[] coverId = (byte[]) row.get("cover_id");
        if (coverId != null) {
            coversToRestore.add(Tuple2.tuple((byte[]) row.get("id"), coverId));
        }
        row.put("cover_id", null);
    }

    /**
     * Verifies that the user's data in {@code album_items} and {@code albums} is the same on the
     * source and destination shards.
     *
     * @param migrationConf provides the shards and the uid to compare
     * @param sourceSchema  schema handed to the per-table comparison
     */
    @Override
    public void checkAllCopied(DjfsCopyConfiguration migrationConf,
            PgSchema sourceSchema)
    {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        DjfsUid uid = migrationConf.getUid();

        DjfsMigrationUtil.checkSameDataInTableForUid(srcShard, dstShard, "album_items", uid, sourceSchema);
        DjfsMigrationUtil.checkSameDataInTableForUid(srcShard, dstShard, "albums", uid, sourceSchema);
    }

    /**
     * Deletes all album data of {@code uid} from {@code shard}.
     * <p>
     * First it detaches album covers. It then deletes {@code album_items} in chunks of
     * {@code batchSize} rows until none remain, and finally deletes the albums.
     * NOTE(review): covers are presumably cleared first because {@code cover_id} references
     * {@code album_items} — confirm against the schema.
     *
     * @param shard     shard to clean
     * @param uid       owner whose data is removed
     * @param batchSize maximum number of {@code album_items} rows deleted per statement
     * @throws IllegalArgumentException if {@code batchSize} is not positive (a {@code LIMIT} of 0
     *                                  would otherwise silently leave all items in place)
     */
    @Override
    public void cleanData(JdbcTemplate3 shard, DjfsUid uid, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () ->
                shard.update("UPDATE disk.albums SET cover_id = NULL WHERE uid = ?", uid.asLong())
        );
        boolean deleted;
        do {
            deleted = DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () ->
                    shard.update(""
                                    + "WITH to_delete(id) AS ("
                                    + "     SELECT id FROM disk.album_items WHERE uid = :uid LIMIT :batchSize"
                                    + ") "
                                    // Scope by uid as well so the delete can never touch other users' rows.
                                    + "DELETE FROM disk.album_items "
                                    + "WHERE uid = :uid AND id = ANY (SELECT id FROM to_delete)",
                            Cf.map(
                                    // Bind the numeric uid, consistent with every other query here.
                                    "uid", uid.asLong(),
                                    "batchSize", batchSize
                            )
                    ) > 0);
        } while (deleted);
        DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () ->
                shard.update("DELETE FROM disk.albums WHERE uid = ?", uid.asLong())
        );
    }

    /**
     * Returns the names of the tables handled by this migration.
     *
     * @return {@code album_items} and {@code albums}
     */
    @Override
    public ListF<String> tables() {
        return Cf.list("album_items", "albums");
    }
}
