package ru.yandex.chemodan.app.djfs.migrator.migrations;

import java.sql.Array;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import javax.sql.DataSource;

import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.springframework.jdbc.core.ColumnMapRowMapper;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.chemodan.app.djfs.core.db.pg.PgCursorUtils;
import ru.yandex.chemodan.app.djfs.core.user.DjfsUid;
import ru.yandex.chemodan.app.djfs.core.util.StreamUtils;
import ru.yandex.chemodan.app.djfs.migrator.DjfsCopyConfiguration;
import ru.yandex.chemodan.app.djfs.migrator.PgSchema;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.spring.jdbc.ArgPreparedStatementSetter;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;
import ru.yandex.misc.test.Assert;

/**
 * Migration of a single table whose rows reference other rows of the same table
 * (a tree stored through a "parent" column).
 * <p>
 * Rows are read level by level: first the rows whose {@code referenceTo} column is NULL, then the rows
 * whose {@code referenceTo} value is one of the {@code referenceFrom} values read on the previous level,
 * and so on. A referenced row is therefore always emitted before the rows that refer to it.
 *
 * @author yappo
 */
public class DjfsTableWithSelfReferenceMigration implements DjfsTableMigration {
    private static final Logger logger = LoggerFactory.getLogger(DjfsTableWithSelfReferenceMigration.class);

    /** Table name; it is prefixed with {@code disk.} when SQL is built. */
    private final String tableName;
    /** Column whose values are referenced by other rows of the same table. */
    private final String referenceFrom;
    /** Column that holds the reference to another row; NULL for the rows read first. */
    private final String referenceTo;

    /**
     * @param tableName     name of the table to migrate, must not be empty
     * @param referenceFrom name of the referenced column, must not be empty
     * @param referenceTo   name of the referencing column, must not be empty
     */
    public DjfsTableWithSelfReferenceMigration(String tableName, String referenceFrom, String referenceTo) {
        Assert.notEmpty(referenceFrom, "referenceFrom is empty");
        Assert.notEmpty(referenceTo, "referenceTo is empty");
        Assert.notEmpty(tableName, "table name is empty");

        this.referenceFrom = referenceFrom;
        this.referenceTo = referenceTo;
        this.tableName = tableName;
    }

    /**
     * Reads the user's rows from the source shard level by level, cuts the resulting stream into batches of
     * {@code migrationConf.getBaseBatchSize()} rows and writes every batch to the destination shard.
     * {@code callback} is run after each written batch. Timing statistics are logged at the end.
     *
     * @param migrationConf  provides the source/destination shards, the uid and the batch size
     * @param databaseSchema schema used to describe the table and passed on to the row copier
     * @param callback       invoked once after every copied batch
     */
    @Override
    public void runCopying(DjfsCopyConfiguration migrationConf, PgSchema databaseSchema,
            Runnable callback)
    {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        Instant begin = Instant.now();
        // AtomicReference is used only as a mutable holder that the lambda below is allowed to update.
        AtomicReference<Duration> writeTime = new AtomicReference<>(Duration.ZERO);

        // try-with-resources: if copying fails half-way, the stream of the level being read is still closed.
        try (TreeTableSpliterator rows = new TreeTableSpliterator(
                srcShard.getDataSource(),
                migrationConf.getBaseBatchSize(), migrationConf.getUid(),
                tableDescription(databaseSchema)))
        {
            Stream<Map<String, Object>> query = StreamSupport.stream(rows, false);
            StreamUtils.batches(query, migrationConf.getBaseBatchSize())
                    .forEach(fetchedRows -> {
                        Instant beginWrite = Instant.now();
                        logger.info("copying {} rows of {}", fetchedRows.size(), tableName);
                        DjfsMigrationUtil.copyRows(dstShard, databaseSchema, tableName, fetchedRows);
                        writeTime.accumulateAndGet(Duration.between(beginWrite, Instant.now()), Duration::plus);
                        callback.run();
                    });
        }

        Duration overallTime = Duration.between(begin, Instant.now());
        // "read time" is everything that was not spent writing, so it also includes the callback time.
        logger.info("overallTime {}, read time {}, write time {}",
                overallTime, overallTime.minus(writeTime.get()), writeTime.get()
        );
    }

    /**
     * Delegates to {@link DjfsMigrationUtil#checkSameDataInTableForUid} to compare this table's data
     * for the migrated uid between the source and the destination shard.
     */
    @Override
    public void checkAllCopied(DjfsCopyConfiguration migrationConf,
            PgSchema sourceSchema) {
        JdbcTemplate3 srcShard = migrationConf.srcShardJdbcTemplate();
        JdbcTemplate3 dstShard = migrationConf.dstShardJdbcTemplate();

        DjfsMigrationUtil.checkSameDataInTableForUid(srcShard, dstShard, tableName, migrationConf.getUid(),
            sourceSchema);
    }


    /**
     * Deletes all rows of this table that belong to {@code uid} with a single DELETE statement, executed
     * through {@link DjfsMigrationUtil#withDisabledMigrationLockCheck}.
     *
     * @param shard     shard to delete from
     * @param uid       owner of the rows to delete
     * @param batchSize not used by this implementation: the rows are removed in one statement
     */
    @Override
    public void cleanData(JdbcTemplate3 shard, DjfsUid uid, int batchSize) {
        DjfsMigrationUtil.withDisabledMigrationLockCheck(shard, () ->
                shard.update("DELETE FROM disk." + tableName + " WHERE uid = ?", uid.asLong())
        );
    }

    /** @return a single-element list with the name of the migrated table */
    @Override
    public ListF<String> tables() {
        return Cf.list(tableName);
    }

    /**
     * Builds the description used by the reader. The data type is taken from the {@code referenceTo}
     * column of the given schema; it is later used as the element type of the SQL array of parent ids.
     * NOTE(review): presumably fails if the schema has no such table or column - confirm against PgSchema.
     */
    private TableDescription tableDescription(PgSchema pgSchema) {
        return new TableDescription(
                tableName,
                referenceFrom,
                referenceTo,
                pgSchema.getTables().getTs(tableName)
                        .getColumns()
                        .filter(column -> column.getName().equals(referenceTo))
                        .first()
                        .getDataType()
        );
    }

    /** What the level-by-level reader needs to know about the table. */
    @Data
    @RequiredArgsConstructor
    private static class TableDescription {
        /** Table name without the {@code disk.} prefix. */
        private final String name;
        /** Referenced column. */
        private final String referenceFrom;
        /** Referencing column. */
        private final String referenceTo;
        /** Data type of the {@code referenceTo} column as reported by the schema. */
        private final String dataType;
    }

    /**
     * Write optimization: we want to insert in batches that are not split by tree level.
     * At the same time the whole tree cannot be fetched from the database through one cursor at once -
     * there is too much data. So the rows are collected level by level and merged into a single stream.
     * The stream is cut into insert batches by the caller.
     * <p>
     * The spliterator is sequential only ({@link #trySplit()} always returns {@code null}) and is not
     * thread-safe. It keeps the ids ({@code referenceFrom} values) of every row emitted on the current
     * level and uses them as the parent ids of the next level.
     */
    private static class TreeTableSpliterator implements Spliterator<Map<String, Object>>, AutoCloseable {
        private final TableDescription table;
        private final DataSource dataSource;
        private final int baseBatchSize;
        private final DjfsUid uid;

        /** Stream of the level being read; kept so that it can be closed. {@code null} when nothing is open. */
        private Stream<Map<String, Object>> currentStream;
        /** Spliterator of {@link #currentStream}. */
        private Spliterator<Map<String, Object>> current;
        /** {@code referenceFrom} values of the rows emitted so far on the current level. */
        private ListF<Object> fetchedIds;
        /** Set once the last level is exhausted or the spliterator is closed; no more rows are produced. */
        private boolean finished;

        /**
         * Opens the query for the first level (rows whose {@code referenceTo} is NULL).
         *
         * @param dataSource    data source of the shard to read from
         * @param baseBatchSize cursor fetch size and maximum number of parent ids per next-level query
         * @param uid           owner of the rows
         * @param table         description of the table being read
         */
        public TreeTableSpliterator(DataSource dataSource, int baseBatchSize, DjfsUid uid, TableDescription table)
        {
            this.dataSource = dataSource;
            this.baseBatchSize = baseBatchSize;
            this.uid = uid;
            this.table = table;
            fetchedIds = Cf.arrayList();
            openLevel(rootQuery(dataSource, baseBatchSize, table, uid));
        }

        /**
         * Emits the next row of the current level. When the level is exhausted, switches to the next one
         * (the rows that reference the rows just emitted) and continues there.
         *
         * @return {@code false} once a level ends without having produced any row, i.e. the tree is read
         */
        @Override
        public boolean tryAdvance(Consumer<? super Map<String, Object>> consumer) {
            if (finished) {
                // Do not poll a spliterator whose stream has already been closed.
                return false;
            }
            Consumer<Map<String, Object>> rememberId = this::saveId;
            Consumer<Map<String, Object>> action = rememberId.andThen(consumer);

            while (!current.tryAdvance(action)) {
                // The level is fully read: release whatever its stream holds before going on.
                closeCurrentLevel();
                if (fetchedIds.isEmpty()) {
                    finished = true;
                    return false;
                }
                // The ids collected on this level are the parent ids of the next one.
                ListF<Object> parentIds = fetchedIds;
                fetchedIds = Cf.arrayList();
                openLevel(siblingsQuery(parentIds, baseBatchSize, dataSource, table, uid));
            }
            return true;
        }

        /**
         * Closes the stream of the level being read, if any, and stops the iteration.
         * Safe to call more than once.
         */
        @Override
        public void close() {
            finished = true;
            closeCurrentLevel();
        }

        /** Makes the given stream the level being read. */
        private void openLevel(Stream<Map<String, Object>> level) {
            currentStream = level;
            current = level.spliterator();
        }

        /**
         * Closes the stream of the current level so that its close handlers run.
         * NOTE(review): assumes the streams returned by PgCursorUtils.queryWithCursor release their
         * cursor/connection in a close handler - confirm; if they do not, this call is a harmless no-op.
         */
        private void closeCurrentLevel() {
            Stream<Map<String, Object>> level = currentStream;
            currentStream = null;
            if (level != null) {
                level.close();
            }
        }

        /** Rows of the user that reference nothing, i.e. whose {@code referenceTo} column is NULL. */
        @NotNull
        private static Stream<Map<String, Object>> rootQuery(DataSource dataSource, int batchSize,
                TableDescription table, DjfsUid uid) {
            return PgCursorUtils.queryWithCursor(
                    dataSource,
                    batchSize,
                    "SELECT * FROM disk." + table.name + " WHERE uid = ? AND " + table.referenceTo + " ISNULL",
                    new ColumnMapRowMapper(),
                    new ArgPreparedStatementSetter(new Object[]{uid.asLong()})
            );
        }

        /**
         * Rows of the user whose {@code referenceTo} value is one of {@code parentIds}. The ids are split
         * into batches of {@code baseBatchSize} and one cursor query is issued per batch; the per-batch
         * results are concatenated lazily.
         */
        @NotNull
        private static Stream<Map<String, Object>> siblingsQuery(ListF<Object> parentIds, int baseBatchSize,
                DataSource dataSource, TableDescription table, DjfsUid uid) {
            return StreamUtils.batches(parentIds.stream(), baseBatchSize)
                    .flatMap(batch -> PgCursorUtils.queryWithCursor(
                            dataSource, baseBatchSize,
                            "SELECT * FROM disk." + table.name + " WHERE uid = ? AND " + table.referenceTo + " = ANY(?)",
                            new ColumnMapRowMapper(),
                            ps -> {
                                ps.setLong(1, uid.asLong());
                                setArray(ps, 2, batch, table.dataType);
                            }
                    ));
        }

        /**
         * Binds {@code objects} as an SQL array of element type {@code dataType} to parameter {@code index}.
         * The array object is freed right after binding.
         * NOTE(review): this relies on the driver copying the array contents inside setArray - confirm
         * for the JDBC driver in use.
         */
        private static void setArray(PreparedStatement ps, int index, ListF<Object> objects, String dataType) throws SQLException
        {
            Object[] arrayToRequest = objects.toArray(Object.class);
            Array array = ps.getConnection().createArrayOf(dataType, arrayToRequest);
            try {
                ps.setArray(index, array);
            } finally {
                array.free();
            }
        }

        /** Remembers the {@code referenceFrom} value of an emitted row as a parent id for the next level. */
        private void saveId(Map<String, Object> row) {
            fetchedIds.add(row.get(table.referenceFrom));
        }

        /** Splitting is not supported: the levels have to be read strictly one after another. */
        @Override
        public Spliterator<Map<String, Object>> trySplit() {
            return null;
        }

        /** The number of rows is not known in advance, which the contract expresses as Long.MAX_VALUE. */
        @Override
        public long estimateSize() {
            return Long.MAX_VALUE;
        }

        /**
         * Always ORDERED and nothing else. The value must not change between calls, and flags such as
         * SIZED or SORTED of an underlying level do not hold for the concatenation of all levels.
         */
        @Override
        public int characteristics() {
            return Spliterator.ORDERED;
        }
    }
}
