package ru.yandex.chemodan.app.djfs.migrator.migrations;

import java.math.BigDecimal;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import javax.sql.DataSource;

import com.google.common.annotations.VisibleForTesting;
import org.jetbrains.annotations.NotNull;
import org.springframework.jdbc.core.ColumnMapRowMapper;
import org.springframework.transaction.support.TransactionTemplate;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.CollectionF;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.collection.Tuple2List;
import ru.yandex.bolts.function.Function;
import ru.yandex.bolts.function.Function0;
import ru.yandex.bolts.function.Function0V;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgCursorUtils;
import ru.yandex.chemodan.app.djfs.core.user.DjfsUid;
import ru.yandex.chemodan.app.djfs.migrator.PgSchema;
import ru.yandex.misc.db.postgres.PgBouncerFamiliarTransactionManager;
import ru.yandex.misc.spring.jdbc.ArgPreparedStatementSetter;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;
import ru.yandex.misc.test.Assert;

public class DjfsMigrationUtil {
    /** Maximum number of differing row pairs collected when two shards disagree. */
    private static final int DIFFERENCE_SAMPLES_SIZE = 10;

    /**
     * Inserts {@code fetchedRows} into {@code disk.<tableName>} on the destination shard as one batch,
     * inside a transaction where the migration lock check is disabled.
     *
     * @param dstShard       shard to insert into
     * @param databaseSchema schema used to resolve column names and placeholder casts
     * @param tableName      table name inside the {@code disk} schema
     * @param fetchedRows    rows keyed by column name; nothing is done when empty
     */
    static void copyRows(JdbcTemplate3 dstShard, PgSchema databaseSchema,
            String tableName, ListF<Map<String, Object>> fetchedRows)
    {
        if (fetchedRows.isEmpty()) {
            return;
        }
        String query = createInsertQuery(databaseSchema, tableName,
                "INSERT INTO disk.{tableName} ({columnNames}) VALUES ({placeholders})"
        );

        withDisabledMigrationLockCheck(
                dstShard,
                () -> dstShard.batchUpdate(query, prepareInsertArgs(databaseSchema, tableName, fetchedRows))
        );
    }

    /**
     * Converts rows into positional argument lists ordered exactly like the table's columns in the schema,
     * matching the column order produced by {@link #createInsertQuery}.
     */
    static ListF<ListF<Object>> prepareInsertArgs(PgSchema databaseSchema, String tableName,
            ListF<Map<String, Object>> fetchedRows)
    {
        PgSchema.Table table = databaseSchema.getTables().getTs(tableName);
        return fetchedRows.map(row -> table.getColumns().map(PgSchema.Column::getName).map(row::get));
    }

    /**
     * Fills an SQL template with the table name, its comma-separated column names and matching placeholders.
     *
     * @param template text containing {@code {tableName}}, {@code {columnNames}} and {@code {placeholders}}
     * @return the template with all three tokens substituted
     */
    @NotNull
    static String createInsertQuery(PgSchema databaseSchema, String tableName, String template) {
        PgSchema.Table table = databaseSchema.getTables().getTs(tableName);
        CollectionF<PgSchema.Column> columns = table.getColumns();

        return template
                .replace("{tableName}", tableName)
                .replace("{columnNames}", columns.map(PgSchema.Column::getName).mkString(", "))
                .replace("{placeholders}", columns.map(DjfsMigrationUtil::placeholderForColumn).mkString(", "));
    }

    /** Runs {@code block} in a transaction on {@code shard} after {@code set local app.disable_migration_lock to true}. */
    static void withDisabledMigrationLockCheck(JdbcTemplate3 shard, Runnable block) {
        doInTransaction(shard.getDataSource(), () -> {
            shard.execute("set local app.disable_migration_lock to true");
            block.run();
        });
    }

    /** Value-returning variant of {@link #withDisabledMigrationLockCheck(JdbcTemplate3, Runnable)}. */
    static <T> T withDisabledMigrationLockCheck(JdbcTemplate3 shard, Supplier<T> block) {
        return doInTransaction(shard.getDataSource(), () -> {
            shard.execute("set local app.disable_migration_lock to true");
            return block.get();
        });
    }

    /** Returns {@code ?}, or {@code ?::schema.type} for USER-DEFINED columns so the driver value is cast explicitly. */
    @NotNull
    private static String placeholderForColumn(PgSchema.Column c) {
        if ("USER-DEFINED".equals(c.getDataType())) {
            return "?::" + c.getUserDefinedTypeSchema() + '.' + c.getUserDefinedTypeName();
        } else {
            return "?";
        }
    }

    /**
     * Compares the per-uid checksum sum of {@code disk.<tableName>} on both shards in a single query.
     * On mismatch, fails via {@link Assert} with up to {@value #DIFFERENCE_SAMPLES_SIZE} differing row samples.
     */
    static void checkSameDataInTableForUid(JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, String tableName,
            DjfsUid uid, PgSchema pgSchema)
    {
        String checksumQuery = checksumQuery(tableName, pgSchema);
        checkSameDataInTable(srcShard, dstShard, tableName, uid, pgSchema, jdbcTemplate3 ->
                jdbcTemplate3.queryForObject(
                        "SELECT sum(" + checksumQuery + ") FROM disk." + tableName + " as t WHERE t.uid = ?",
                        String.class, uid.asLong()
                )
        );
    }

    /**
     * Same as {@link #checkSameDataInTableForUid}, but accumulates the checksum over keyset-paginated batches
     * of {@code batchSize} rows ordered by {@code sortByColumn}, to avoid one huge aggregate query.
     * The batch maximum of {@code sortByColumn} is parsed with {@link UUID#fromString}, so the column is
     * expected to hold UUID values.
     */
    static void checkSameDataInTableForUidWithBatches(JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, String tableName,
            String sortByColumn, DjfsUid uid, int batchSize, PgSchema pgSchema)
    {
        String checksumQuery = checksumQuery(tableName, pgSchema);
        checkSameDataInTable(srcShard, dstShard, tableName, uid, pgSchema, jdbcTemplate3 -> {
            Option<UUID> maxToSort = Option.empty();
            BigDecimal checksum = BigDecimal.ZERO;
            do {
                Tuple2<Option<UUID>, Option<BigDecimal>> result = jdbcTemplate3.query(
                        "SELECT max(" + sortByColumn + "::text) as maxToSort, sum(" + checksumQuery + ") as checksum "
                                + " FROM (SELECT * FROM disk." + tableName
                                + " WHERE uid = :uid AND CASE WHEN :maxToSort::text ISNULL THEN true ELSE " + sortByColumn + " > :maxToSort END"
                                + " ORDER BY " + sortByColumn + " LIMIT :batchSize) as tmp",
                        (rs, i) -> Tuple2.tuple(
                                Option.ofNullable(rs.getString("maxToSort")).map(UUID::fromString),
                                Option.ofNullable(rs.getBigDecimal("checksum"))
                        ),
                        Cf.map(
                                // bind the numeric uid, like every other query in this class
                                "uid", uid.asLong(),
                                "batchSize", batchSize,
                                "maxToSort", maxToSort.getOrNull()
                        )
                ).first();
                // An empty batch yields max() = NULL, which ends the pagination.
                maxToSort = result.get1();
                if (result.get2().isPresent()) {
                    checksum = checksum.add(result.get2().get());
                }
            } while (maxToSort.isPresent());

            return checksum.toString();
        });
    }

    /** Builds the SQL expression that renders all table columns, in schema order, as one record text. */
    @NotNull
    private static String textQuery(String tableName, PgSchema pgSchema) {
        String columnNames = pgSchema.getTables().getTs(tableName).getColumns()
                .stream()
                .map(PgSchema.Column::getName)
                .collect(Collectors.joining(","));
        // (col1,col2,...)::text packs all columns into a record and casts it to text.
        // t::text would also work, but listing the columns pins the order explicitly.
        return "(" + columnNames + ")::text";
    }

    /** Builds the SQL expression producing a per-row 64-bit checksum (as {@code bigint}) of the row text. */
    @NotNull
    private static String checksumQuery(String tableName, PgSchema pgSchema) {
        String columns = textQuery(tableName, pgSchema);
        // md5(...) returns the hash as 32 hex characters; the first 16 of them are 64 bits of hash.
        // ('x' || <16 hex chars>)::bit(64)::bigint turns those hex digits into a signed bigint.
        // (Casting md5(...) text to bytea and hex-encoding it again would re-encode the ASCII hex characters
        // and keep only 32 bits of the hash after truncation.)
        return "('x'||substr(md5(" + columns + "), 1, 16))::bit(64)::bigint";
    }

    /**
     * Evaluates {@code counter} on both shards; when the results differ, computes sample differences and fails
     * via {@link Assert#isEmpty} with the table name in the message.
     */
    private static void checkSameDataInTable(JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, String tableName,
            DjfsUid uid, PgSchema pgSchema, Function<JdbcTemplate3, String> counter)
    {
        String onSource = counter.apply(srcShard);
        String onDestination = counter.apply(dstShard);
        if (!Objects.equals(onSource, onDestination)) {
            Tuple2List<Map<String, Object>, Map<String, Object>> difference =
                    calculateDifferenceOnShards(srcShard, dstShard, tableName, uid, pgSchema);
            Assert.isEmpty(difference, "different records in table " + tableName);
        }
    }

    /**
     * Streams the uid's rows from both shards (via cursors) in the same order and returns up to
     * {@value #DIFFERENCE_SAMPLES_SIZE} differences.
     *
     * <p>Each entry is {@code (srcRow, dstRow)}:
     * <ul>
     *   <li>{@code srcRow} with {@code null} means the row exists only on the source;</li>
     *   <li>{@code null} with {@code dstRow} means it exists only on the destination;</li>
     *   <li>both rows means equal keys with different content.</li>
     * </ul>
     * Rows are column maps containing {@code stringRow} and {@code checksum}, plus {@code fid} for files/folders.
     */
    @VisibleForTesting
    public static Tuple2List<Map<String, Object>, Map<String, Object>> calculateDifferenceOnShards(
            JdbcTemplate3 srcShard, JdbcTemplate3 dstShard, String tableName, DjfsUid uid, PgSchema pgSchema
    )
    {
        // files/folders are large: order by fid so the sort can use an index; other tables order by checksum.
        boolean orderedByFid = tableName.equals("files") || tableName.equals("folders");
        String orderByField = orderedByFid ? "fid" : "checksum";
        String fields = orderedByFid ? "fid," : "";

        // The Java-side merge must see rows in exactly the order PostgreSQL returned them (ASC => NULLS LAST).
        Comparator<Object> keyOrder = Comparator.nullsLast(DjfsMigrationUtil::compareInPgOrder);
        Comparator<Map<String, Object>> comparator = Comparator.comparing(row -> row.get(orderByField), keyOrder);

        String checksumQuery = checksumQuery(tableName, pgSchema);
        String textQuery = textQuery(tableName, pgSchema);
        String query = "select " + fields + " " + textQuery + " as stringRow, " + checksumQuery + " as checksum"
                       + " from disk." + tableName
                       + " where uid = ? order by " + orderByField;

        return PgCursorUtils.queryWithCursor(srcShard.getDataSource(), 50000,
                query, new ColumnMapRowMapper(), new ArgPreparedStatementSetter(new Object[]{uid.asLong()}),
                srcIterator -> PgCursorUtils.queryWithCursor(dstShard.getDataSource(), 50000,
                        query, new ColumnMapRowMapper(), new ArgPreparedStatementSetter(new Object[]{uid.asLong()}),
                        dstIterator -> difference(comparator, srcIterator, dstIterator)
                )
        );
    }

    /**
     * Compares two non-null sort keys the way PostgreSQL orders them.
     *
     * <p>{@link UUID#compareTo} compares the two halves as signed longs. PostgreSQL orders {@code uuid}
     * by unsigned bytes (big-endian), which equals an unsigned comparison of the halves.
     * Other types use their natural ordering (for text keys this assumes a collation consistent with
     * {@link String#compareTo} — NOTE(review): confirm if ever used with text sort keys).
     */
    @SuppressWarnings("unchecked")
    private static int compareInPgOrder(Object left, Object right) {
        if (left instanceof UUID && right instanceof UUID) {
            UUID l = (UUID) left;
            UUID r = (UUID) right;
            int byMostSignificant = Long.compareUnsigned(l.getMostSignificantBits(), r.getMostSignificantBits());
            return byMostSignificant != 0
                    ? byMostSignificant
                    : Long.compareUnsigned(l.getLeastSignificantBits(), r.getLeastSignificantBits());
        }
        return ((Comparable<Object>) left).compareTo(right);
    }

    /**
     * Sorted merge of two row iterators already ordered consistently with {@code comparator};
     * collects at most {@value #DIFFERENCE_SAMPLES_SIZE} differences (see {@link #calculateDifferenceOnShards}).
     */
    private static Tuple2List<Map<String, Object>, Map<String, Object>> difference(
            Comparator<Map<String, Object>> comparator, Iterator<Map<String, Object>> srcIterator,
            Iterator<Map<String, Object>> dstIterator)
    {
        Tuple2List<Map<String, Object>, Map<String, Object>> difference = Tuple2List.arrayList();

        if (!srcIterator.hasNext() || !dstIterator.hasNext()) {
            // At most one side has rows: each of them is missing on the other side.
            toStream(dstIterator).limit(DIFFERENCE_SAMPLES_SIZE).forEach(item -> difference.add(null, item));
            toStream(srcIterator).limit(DIFFERENCE_SAMPLES_SIZE).forEach(item -> difference.add(item, null));
            return difference;
        }

        Option<Map<String, Object>> srcCurrent = nextOrEmpty(srcIterator);
        Option<Map<String, Object>> dstCurrent = nextOrEmpty(dstIterator);
        while ((srcCurrent.isPresent() || dstCurrent.isPresent()) && difference.size() < DIFFERENCE_SAMPLES_SIZE) {
            // An exhausted side sorts after everything, so the remaining rows of the other side are reported.
            int compare;
            if (!dstCurrent.isPresent()) {
                compare = -1;
            } else if (!srcCurrent.isPresent()) {
                compare = 1;
            } else {
                compare = comparator.compare(srcCurrent.get(), dstCurrent.get());
            }

            if (compare < 0) {
                // Row present only on the source shard.
                difference.add(srcCurrent.get(), null);
                srcCurrent = nextOrEmpty(srcIterator);
            } else if (compare > 0) {
                // Row present only on the destination shard.
                difference.add(null, dstCurrent.get());
                dstCurrent = nextOrEmpty(dstIterator);
            } else {
                // Same key on both shards: report only if the row content differs.
                if (!Objects.equals(srcCurrent.get().get("stringRow"), dstCurrent.get().get("stringRow"))) {
                    difference.add(srcCurrent.get(), dstCurrent.get());
                }
                srcCurrent = nextOrEmpty(srcIterator);
                dstCurrent = nextOrEmpty(dstIterator);
            }
        }
        return difference;
    }

    /** Wraps the remaining elements of an iterator into a sequential stream (consumes the iterator). */
    @NotNull
    private static <T> Stream<T> toStream(Iterator<T> dstIterator) {
        Iterable<T> iter = () -> dstIterator;
        return StreamSupport.stream(iter.spliterator(), false);
    }

    /** Returns the iterator's next element, or an empty option when it is exhausted. */
    @NotNull
    private static Option<Map<String, Object>> nextOrEmpty(Iterator<Map<String, Object>> iterator) {
        if (iterator.hasNext()) {
            return Option.of(iterator.next());
        } else {
            return Option.empty();
        }
    }

    /** Runs {@code action} inside a transaction managed by a {@link PgBouncerFamiliarTransactionManager}. */
    public static void doInTransaction(DataSource dataSource, Function0V action) {
        new TransactionTemplate(new PgBouncerFamiliarTransactionManager(dataSource))
                .execute(action.asFunctionReturnParam()::apply);
    }

    /** Runs {@code action} inside a transaction and returns its result. */
    public static <R> R doInTransaction(DataSource dataSource, Function0<R> action) {
        return new TransactionTemplate(new PgBouncerFamiliarTransactionManager(dataSource))
                .execute(action.asFunction()::apply);
    }
}
