package ru.yandex.chemodan.app.djfs.migrator.migrations;

import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import org.springframework.jdbc.core.ColumnMapRowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgArray;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgCursorUtils;
import ru.yandex.chemodan.app.djfs.core.db.pg.ResultSetUtils;
import ru.yandex.chemodan.app.djfs.core.user.DjfsUid;
import ru.yandex.chemodan.app.djfs.migrator.DjfsCopyConfiguration;
import ru.yandex.chemodan.app.djfs.migrator.PgSchema;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;

/**
 * Table migration for a user's file data.
 *
 * <p>Covers {@code disk.files} and {@code disk.version_data}, together with the
 * {@code disk.storage_files} and {@code disk.duplicated_storage_files} rows that those rows reference
 * through {@code storage_id}.
 */
public class DjfsFilesDataMigration implements DjfsTableMigration {
    private static final Logger logger = LoggerFactory.getLogger(DjfsFilesDataMigration.class);

    /**
     * Copies the user's {@code disk.files} rows first, then the user's {@code disk.version_data} rows.
     *
     * <p>Each batch is written together with the storage rows it references. {@code callback} runs after
     * every batch.
     */
    @Override
    public void runCopying(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema,
            Runnable callback)
    {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        copyFiles(migrationConf, databaseSchema, srcShard, dstShard, callback);

        copyVersionData(migrationConf, databaseSchema, srcShard, dstShard, callback);
    }

    /**
     * Copies the user's {@code disk.files} rows.
     *
     * <p>A {@code true} {@code is_live_photo} value is replaced with {@code null} before writing.
     */
    private void copyFiles(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema, JdbcTemplate3 srcShard,
            JdbcTemplate3 dstShard, Runnable callback)
    {
        // see DjfsAdditionalFileLinksMigration
        copyUserRowsWithStorage(migrationConf, databaseSchema, srcShard, dstShard, callback,
                "files", "files", DjfsFilesDataMigration::replaceIsLivePhotoWithNull);
    }

    /** Copies the user's {@code disk.version_data} rows unchanged. */
    private void copyVersionData(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema,
            JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, Runnable callback)
    {
        copyUserRowsWithStorage(migrationConf, databaseSchema, srcShard, dstShard, callback,
                "version_data", "version data", row -> { });
    }

    /**
     * Streams the user's rows of {@code disk.<tableName>} from the source shard and writes them to the
     * destination shard.
     *
     * <p>Rows are read in cursor batches of {@code migrationConf.getBaseBatchSize()}. Each batch is
     * written together with the {@code storage_files} and {@code duplicated_storage_files} rows it
     * references. {@code callback} runs after each batch.
     *
     * <p>When finished, two timing summaries are logged:
     * <ul>
     *   <li>storage read/write time;</li>
     *   <li>the table's own time (total elapsed minus storage time), split into write time and a
     *       remainder that is reported as read time.</li>
     * </ul>
     *
     * @param tableName       table in the {@code disk} schema; rows are selected by {@code uid}
     * @param logLabel        name used in the timing log messages
     * @param rowPreprocessor applied in place to every fetched row before it is written
     */
    private void copyUserRowsWithStorage(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema,
            JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, Runnable callback,
            String tableName, String logLabel, Consumer<Map<String, Object>> rowPreprocessor)
    {
        AtomicReference<Duration> storageReadDuration = new AtomicReference<>(Duration.ZERO);
        AtomicReference<Duration> storageWriteDuration = new AtomicReference<>(Duration.ZERO);
        AtomicReference<Duration> tableWriteDuration = new AtomicReference<>(Duration.ZERO);

        Instant begin = Instant.now();

        PgCursorUtils.queryWithCursorAsBatches(
                srcShard.getDataSource(), migrationConf.getBaseBatchSize(),
                "SELECT * FROM disk." + tableName + " WHERE uid = ?",
                new ColumnMapRowMapper(),
                migrationConf.getUid().asLong()
        ).forEach(fetchedRows -> {
            // A null storage_id cannot match any storage row, so it is not sent to the source shard.
            ListF<UUID> storageIds = fetchedRows.map(row -> (UUID) row.get("storage_id")).filterNotNull();
            fetchedRows.forEach(rowPreprocessor);

            ListF<Map<String, Object>> storageFiles =
                    fetchStorageFiles(srcShard, storageReadDuration, storageIds);
            ListF<Tuple2<UUID, String>> storageDuplicates =
                    fetchStorageDuplicates(srcShard, storageReadDuration, storageIds);

            // Storage rows are written first, in the same transaction as the rows that reference them.
            DjfsMigrationUtil.doInTransaction(dstShard.getDataSource(), () -> {
                copyStorageFiles(databaseSchema, dstShard, storageWriteDuration, storageFiles,
                        storageDuplicates);

                Instant writeBegin = Instant.now();
                DjfsMigrationUtil.copyRows(dstShard, databaseSchema, tableName, fetchedRows);
                addElapsed(tableWriteDuration, writeBegin);
            });
            callback.run();
        });

        Duration storageReadTime = storageReadDuration.get();
        Duration storageWriteTime = storageWriteDuration.get();
        Duration storageOverallTime = storageReadTime.plus(storageWriteTime);
        logger.info("storage (" + logLabel + "): overallTime {}, read time {}, write time {}",
                storageOverallTime,
                storageReadTime,
                storageWriteTime
        );

        // The table's own time excludes storage work. Everything that is not a table write
        // (cursor reads, callbacks, transaction overhead) is reported as read time.
        Duration overallTime = Duration.between(begin, Instant.now()).minus(storageOverallTime);
        Duration writeTime = tableWriteDuration.get();
        logger.info(logLabel + ": overallTime {}, read time {}, write time {}",
                overallTime,
                overallTime.minus(writeTime),
                writeTime
        );
    }

    /** Adds the time elapsed since {@code since} to {@code accumulator}. */
    private static void addElapsed(AtomicReference<Duration> accumulator, Instant since) {
        accumulator.accumulateAndGet(Duration.between(since, Instant.now()), Duration::plus);
    }

    /**
     * Reads the source {@code disk.storage_files} rows whose {@code storage_id} is in {@code storageIds}.
     *
     * <p>Returns an empty list without querying when {@code storageIds} is empty. Query time is added
     * to {@code storageReadDuration}.
     */
    private ListF<Map<String, Object>> fetchStorageFiles(JdbcTemplate3 srcShard,
            AtomicReference<Duration> storageReadDuration, ListF<UUID> storageIds)
    {
        if (storageIds.isEmpty()) {
            return Cf.list();
        }
        Instant readBegin = Instant.now();
        ListF<Map<String, Object>> fetchedStorageFiles = srcShard.query(
                "SELECT * FROM disk.storage_files WHERE storage_id = ANY (:ids)",
                new ColumnMapRowMapper(), Cf.map("ids", PgArray.uuidArray(storageIds.toArray(new UUID[0])))
        );
        addElapsed(storageReadDuration, readBegin);
        return fetchedStorageFiles;
    }

    /**
     * Reads the source {@code disk.duplicated_storage_files} (storage_id, stid) pairs whose
     * {@code storage_id} is in {@code storageIds}.
     *
     * <p>Returns an empty list without querying when {@code storageIds} is empty. Query time is added
     * to {@code storageReadDuration}.
     */
    private ListF<Tuple2<UUID, String>> fetchStorageDuplicates(JdbcTemplate3 srcShard,
            AtomicReference<Duration> storageReadDuration, ListF<UUID> storageIds)
    {
        if (storageIds.isEmpty()) {
            return Cf.list();
        }
        Instant readBegin = Instant.now();
        ListF<Tuple2<UUID, String>> storageIdsWithSids = srcShard.query(
                "SELECT storage_id, stid FROM disk.duplicated_storage_files WHERE storage_id = ANY (:ids)",
                (rs, i) -> Tuple2.tuple(rs.getObject("storage_id", UUID.class), rs.getString("stid")),
                Cf.map("ids", PgArray.uuidArray(storageIds.toArray(new UUID[0])))
        );
        addElapsed(storageReadDuration, readBegin);
        return storageIdsWithSids;
    }

    /**
     * Replaces a {@code true} {@code is_live_photo} value in a fetched {@code files} row with {@code null}.
     *
     * <p>The row is modified in place; other values are left unchanged. See
     * DjfsAdditionalFileLinksMigration for the related migration.
     */
    private static void replaceIsLivePhotoWithNull(Map<String, Object> row) {
        if (Boolean.TRUE.equals(row.get("is_live_photo"))) {
            row.put("is_live_photo", null);
        }
    }

    /**
     * Writes fetched source storage rows to the destination shard.
     *
     * <p>Does nothing when {@code fetchedStorageFiles} is empty. Otherwise it runs inside
     * {@code DjfsMigrationUtil.withDisabledMigrationLockCheck} and:
     * <ol>
     *   <li>inserts the {@code storage_files} rows, ignoring conflicts, so an existing destination row
     *       for the same key is kept;</li>
     *   <li>takes every (storage_id, stid) pair from the fetched storage files and from
     *       {@code previousStorageDuplicates}. Each pair that does not match the destination's
     *       {@code storage_files} row for that storage_id is inserted into
     *       {@code duplicated_storage_files}, ignoring conflicts.</li>
     * </ol>
     *
     * <p>Elapsed time is added to {@code storageWriteDuration}.
     */
    private void copyStorageFiles(PgSchema databaseSchema, JdbcTemplate3 dstShard,
            AtomicReference<Duration> storageWriteDuration, ListF<Map<String, Object>> fetchedStorageFiles,
            ListF<Tuple2<UUID, String>> previousStorageDuplicates)
    {
        if (fetchedStorageFiles.isEmpty()) {
            return;
        }
        Instant writeBegin = Instant.now();

        String query = DjfsMigrationUtil.createInsertQuery(databaseSchema, "storage_files",
                "INSERT INTO disk.{tableName} ({columnNames}) VALUES ({placeholders}) ON CONFLICT DO NOTHING"
        );

        DjfsMigrationUtil.withDisabledMigrationLockCheck(
                dstShard,
                () -> {
                    dstShard.batchUpdate(query, DjfsMigrationUtil
                            .prepareInsertArgs(databaseSchema, "storage_files", fetchedStorageFiles));

                    ListF<Tuple2<UUID, String>> storageIdsWithSids = fetchedStorageFiles.map(row ->
                            Tuple2.tuple((UUID) row.get("storage_id"), (String) row.get("stid"))
                    ).plus(previousStorageDuplicates);
                    dstShard.update(""
                                    + "WITH input(storage_id, stid) AS (SELECT unnest(?), unnest(?)),"
                                    + "local(storage_id, stid) AS ("
                                    + "     SELECT storage_id, stid FROM disk.storage_files "
                                    + "         WHERE storage_id = ANY(SELECT storage_id FROM input)"
                                    + "),"
                                    + "with_conflict(storage_id, stid) AS("
                                    + "     SELECT storage_id, stid FROM input "
                                    + "         WHERE (storage_id, stid) NOT IN (SELECT storage_id, stid FROM local)"
                                    + ")"
                                    + "INSERT INTO disk.duplicated_storage_files(storage_id, stid) "
                                    + "     SELECT storage_id, stid FROM with_conflict"
                                    + "     ON CONFLICT DO NOTHING",
                            PgArray.uuidArray(storageIdsWithSids.map(Tuple2::get1).toArray(UUID.class)),
                            PgArray.textArray(storageIdsWithSids.map(Tuple2::get2).toArray(String.class))
                    );
                }
        );
        addElapsed(storageWriteDuration, writeBegin);
    }

    /**
     * Verifies that the user's {@code files} rows (compared in batches keyed by {@code fid}) and
     * {@code version_data} rows are the same on both shards.
     */
    @Override
    public void checkAllCopied(DjfsCopyConfiguration migrationConf, PgSchema pgSchema) {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        DjfsUid uid = migrationConf.getUid();

        DjfsMigrationUtil.checkSameDataInTableForUidWithBatches(
                srcShard, dstShard, "files", "fid", uid, migrationConf.getBaseBatchSize(), pgSchema
        );
        DjfsMigrationUtil.checkSameDataInTableForUid(srcShard, dstShard, "version_data", uid, pgSchema);
        // Checking storage_files is pointless: with duplicates only storage_id would match, and storage_id
        // is already covered by the files and version_data checks.
    }

    /**
     * Removes the user's {@code files} rows first, then any remaining {@code version_data} rows.
     *
     * <p>Storage rows that are no longer referenced are removed along the way.
     */
    @Override
    public void cleanData(JdbcTemplate3 shard, DjfsUid uid, int batchSize) {
        removeFiles(shard, uid, batchSize);
        removeVersionData(shard, uid, batchSize);
    }

    /**
     * Deletes the user's {@code disk.files} rows, at most {@code batchSize} per iteration.
     *
     * <p>Each iteration:
     * <ul>
     *   <li>deletes the selected files;</li>
     *   <li>in the same statement, deletes {@code disk.version_data} rows with the same uid and a
     *       storage_id of one of those files;</li>
     *   <li>then removes the storage rows that are no longer referenced.</li>
     * </ul>
     * Repeats until a batch selects no files.
     */
    private void removeFiles(JdbcTemplate3 shard, DjfsUid uid, int batchSize) {
        boolean somethingDeleted;
        do {
            somethingDeleted = DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () -> {
                ListF<UUID> storageFilesToDelete = shard.query(""
                                + "WITH filesToDelete(uid, fid, storage_id) AS ("
                                + "     SELECT uid, fid, storage_id FROM disk.files "
                                + "         WHERE uid = :uid "
                                + "         LIMIT :batchSize"
                                + "),"
                                + "filesDel AS ("
                                + "     DELETE FROM disk.files WHERE (uid, fid) = ANY (SELECT uid, fid FROM filesToDelete)"
                                + "),"
                                + "versionDataDel AS ("
                                + "     DELETE FROM disk.version_data WHERE (uid, storage_id) = ANY (SELECT uid, storage_id FROM filesToDelete)"
                                + ")"
                                + "SELECT storage_id FROM filesToDelete",
                        (rs, i) -> ResultSetUtils.getUuid(rs, "storage_id"),
                        Cf.map(
                                "uid", uid.asLong(),
                                "batchSize", batchSize
                        )
                );
                deleteStorageFilesIfPossible(shard, storageFilesToDelete);
                return storageFilesToDelete.isNotEmpty();
            });
        } while (somethingDeleted);
    }

    /**
     * Deletes the user's remaining {@code disk.version_data} rows, at most {@code batchSize} per iteration.
     *
     * <p>After each iteration, storage rows referenced by the deleted rows (those with a non-null
     * storage_id) are removed if they are no longer referenced. Repeats until a batch selects no rows.
     *
     * <p>Precondition: disk.files has already been cleaned.
     */
    private void removeVersionData(JdbcTemplate3 shard, DjfsUid uid, int batchSize) {
        boolean somethingDeleted;
        do {
            somethingDeleted = DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () -> {
                ListF<Option<UUID>> wasDeleted = shard.query(""
                                + "WITH versionDataToDelete(id, storage_id) AS ("
                                + "     SELECT id, storage_id FROM disk.version_data "
                                + "         WHERE uid = :uid "
                                + "         LIMIT :batchSize"
                                + "),"
                                + "versionDataDel AS (DELETE FROM disk.version_data WHERE id = ANY (SELECT id FROM versionDataToDelete))"
                                + "SELECT storage_id FROM versionDataToDelete",
                        (rs, i) -> ResultSetUtils.getUuidO(rs, "storage_id"),
                        Cf.map(
                                "uid", uid.asLong(),
                                "batchSize", batchSize
                        ));
                ListF<UUID> storageFilesToDelete = wasDeleted.filter(Option::isPresent).map(Option::get);
                deleteStorageFilesIfPossible(shard, storageFilesToDelete);
                return wasDeleted.isNotEmpty();
            });
        } while (somethingDeleted);
    }

    /**
     * Deletes {@code duplicated_storage_files} and {@code storage_files} rows for the given storage ids.
     *
     * <p>Only ids that no longer appear in any {@code disk.files} or {@code disk.version_data} row are
     * deleted. The check is not restricted to a uid. Does nothing for an empty list.
     */
    private void deleteStorageFilesIfPossible(JdbcTemplate3 shard, ListF<UUID> toDelete) {
        if (toDelete.isEmpty()) {
            return;
        }
        shard.update(""
                        + "WITH input(storage_id) AS (SELECT unnest(?)),"
                        + "to_delete(storage_id) AS ("
                        + "     SELECT input.storage_id FROM input "
                        + "         LEFT JOIN disk.files ON input.storage_id = files.storage_id "
                        + "         LEFT JOIN disk.version_data ON input.storage_id = version_data.storage_id "
                        + "     WHERE "
                        + "         files.storage_id ISNULL AND version_data.storage_id ISNULL"
                        + "),"
                        + "delete_duplicates AS ("
                        + "     DELETE FROM disk.duplicated_storage_files WHERE storage_id IN (SELECT storage_id FROM to_delete)"
                        + ")"
                        + "DELETE FROM disk.storage_files WHERE storage_id IN (SELECT storage_id FROM to_delete)",
                PgArray.uuidArray(toDelete.toArray(new UUID[0]))
        );
    }

    /** Tables handled by this migration. */
    @Override
    public ListF<String> tables() {
        return Cf.list("storage_files", "files", "version_data", "duplicated_storage_files");
    }

}
