package ru.yandex.direct.mysql.ytsync.export.util;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import one.util.streamex.EntryStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.direct.mysql.ytsync.common.compatibility.KosherCompressorFactory;
import ru.yandex.direct.mysql.ytsync.common.compatibility.KosherFlatRowSerializer;
import ru.yandex.direct.mysql.ytsync.common.row.FlatRow;
import ru.yandex.direct.mysql.ytsync.common.util.ParallelRunner;
import ru.yandex.direct.mysql.ytsync.export.components.ConnectionsCache;
import ru.yandex.direct.mysql.ytsync.export.task.TableExportTask;
import ru.yandex.direct.mysql.ytsync.export.util.iterators.SqlLoaderToFlatRowIterator;
import ru.yandex.direct.mysql.ytsync.export.util.queue.ShardedQueue;
import ru.yandex.direct.mysql.ytsync.export.util.queue.ShardedQueueWithLimit;
import ru.yandex.direct.utils.Checked;
import ru.yandex.direct.utils.InterruptedRuntimeException;
import ru.yandex.direct.ytwrapper.YtUtils;
import ru.yandex.inside.yt.kosher.Yt;
import ru.yandex.inside.yt.kosher.common.GUID;
import ru.yandex.inside.yt.kosher.cypress.CypressNodeType;
import ru.yandex.inside.yt.kosher.cypress.YPath;
import ru.yandex.inside.yt.kosher.impl.transactions.utils.YtTransactionsUtils;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTreeBuilder;
import ru.yandex.inside.yt.kosher.tables.YTableEntryType;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yt.ytclient.tables.ColumnValueType;

import static com.google.common.collect.Sets.intersection;
import static ru.yandex.direct.mysql.ytsync.export.util.RangesStateRecord.DB_NAME_YTREE_FIELD_NAME;
import static ru.yandex.direct.mysql.ytsync.export.util.RangesStateRecord.RANGES_REMAINING_YTREE_FIELD_NAME;

/**
 * Класс для импорта таблицы из базы Директа в статическую таблицу yt
 */
public class ShardedTableImporter {
    private static final Logger logger = LoggerFactory.getLogger(ShardedTableImporter.class);

    private final Yt yt;
    private final KosherCompressorFactory compressorFactory;
    private final ConnectionsCache connectionsCache;
    private final TableExportTask template;
    private final int chunkSize;
    private final ShardedQueue<IdRange> rangesQueue;
    private final List<ParallelRunner<?>> runners = new ArrayList<>();
    private final ObjectMapper objectMapper;
    private final YPath tablePath;
    private final MapF<String, YTreeNode> commonTableAttributes;
    private final int threadsCount;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    /**
     * Периодичность сохранения промежуточного состояния импорта.
     */
    private final Duration flushInterval;

    public ShardedTableImporter(
            Yt yt,
            KosherCompressorFactory compressorFactory,
            ConnectionsCache connectionsCache,
            String dbNamePrefix,
            TableExportTask template,
            int chunkSize, long chunksLimit,
            YPath tablePath,
            MapF<String, YTreeNode> commonTableAttributes,
            int threadsCount,
            int flushIntervalMinutes) {
        this.yt = yt;
        this.compressorFactory = compressorFactory;
        this.connectionsCache = connectionsCache;
        this.template = template;
        this.chunkSize = chunkSize;
        this.tablePath = tablePath;
        this.commonTableAttributes = commonTableAttributes;
        this.threadsCount = threadsCount;
        this.flushInterval = Duration.ofMinutes(flushIntervalMinutes);
        //
        this.objectMapper = new ObjectMapper();
        this.rangesQueue = new ShardedQueueWithLimit(chunksLimit);
    }

    /**
     * Добавляет указанный шард в последующий импорт
     * <p>
     * Можно параллельно вызывать из нескольких потоков
     */
    private void addDbName(String dbName) throws SQLException {
        Connection conn = connectionsCache.getConnection(dbName);
        try {
            int idColumnIndex = template.getSchema().findColumn(template.getIdColumn());
            ColumnValueType columnType = template.getSchema().getColumnType(idColumnIndex);
            boolean idColumnHasStringLikeType = columnType.isStringLikeType() || columnType == ColumnValueType.UINT64;
            TableIdRangeEstimator estimator =
                    new TableIdRangeEstimator(conn, template.getTableName(), template.getIdColumn(),
                            idColumnHasStringLikeType);
            // Получаем набор диапазонов примерно по chunkSize строк
            List<IdRange> ranges = estimator.estimateTableRanges(chunkSize);
            if (!ranges.isEmpty()) {
                // Добавляем в очередь соответствующего шарда
                rangesQueue.add(dbName, ranges);
                logger.info("Generated {} ranges for table {} from {}", ranges.size(), template.getTableName(), dbName);
            } else {
                logger.info("No ranges for table {} from {}", template.getTableName(), dbName);
            }
        } finally {
            connectionsCache.releaseConnection(dbName, conn);
        }
    }

    private void addShards(List<String> dbNames) throws InterruptedException {
        if (!dbNames.isEmpty()) {
            ParallelRunner<String> parallelRunner = new ParallelRunner<>(dbNames,
                    dbName -> {
                        try {
                            addDbName(dbName);
                        } catch (SQLException e) {
                            Checked.throwWrapped(e);
                        }
                    });
            runners.add(parallelRunner);
            parallelRunner.run();
            runners.remove(parallelRunner);
        }
    }

    private YPath getStateTablePath() {
        return YPath.simple(tablePath.toString() + "_state");
    }

    public void run(List<String> dbNames) throws InterruptedException {
        if (threadsCount <= 0) {
            throw new IllegalArgumentException("threadsCount must be greater than zero");
        }

        logger.info("Importing table {} to {}", template.getTableName(), tablePath);

        // Загружаем из YT диапазоны, которые нужно импортировать, или же загружаем их из БД (если таблицы нет)
        initRangesQueue(dbNames);

        List<ShardedSqlLoader> shardedSqlLoaders = Collections.synchronizedList(new ArrayList<>());
        try {
            for (int i = 0; i < threadsCount; ++i) {
                shardedSqlLoaders.add(
                        new ShardedSqlLoader(
                                connectionsCache,
                                rangesQueue,
                                template.getSqlTemplate(),
                                template.getExportFlatRowCreator()
                        ));
            }

            YtTransactionsUtils.withTransaction(
                    yt,
                    Duration.ofMinutes(1),
                    Optional.of(Duration.ofSeconds(5)),
                    tx -> {
                        // Создаём таблицу, если она ещё не существует
                        createDataTableIfNotExists(tx.getId());
                        saveRangesQueueState(tx.getId());
                        return null;
                    });

            while (!rangesQueue.isEmpty()) {
                if (closed.get()) {
                    throw new RuntimeException("Import aborted from outside");
                }

                for (ShardedSqlLoader sqlLoader : shardedSqlLoaders) {
                    sqlLoader.resume();
                }

                // Выполняем всё в транзакции с коротким таймаутом и пингами каждые 5 секунд
                // В этом случае транзакции не будут слишком долго висеть после некорректного завершения
                YtTransactionsUtils.withTransaction(
                        yt,
                        Duration.ofMinutes(1),
                        Optional.of(Duration.ofSeconds(5)),
                        tx -> {
                            try {
                                ProcessingSpeedCounter counter = new ProcessingSpeedCounter(rangesQueue);
                                ParallelRunner<ShardedSqlLoader> parallelRunner = null;
                                try {
                                    parallelRunner = new ParallelRunner<>(shardedSqlLoaders,
                                            shardedSqlLoader -> runWorker(shardedSqlLoader, tablePath, tx.getId(),
                                                    counter));
                                    runners.add(parallelRunner);
                                    counter.setParallelRunner(parallelRunner);

                                    // Таймер запаузит чтение чанков из таблиц через некоторое время, чтобы
                                    // сохранить достигнутое состояние
                                    Timer flushTimer = new Timer("table-importer-flush-timer", true);
                                    flushTimer.schedule(new TimerTask() {
                                        @Override
                                        public void run() {
                                            // По индексу, потому что iterator() не является synchronized
                                            for (int i = 0; i < shardedSqlLoaders.size(); i++) {
                                                ShardedSqlLoader sqlLoader = shardedSqlLoaders.get(i);
                                                sqlLoader.pause();
                                            }
                                        }
                                    }, flushInterval.toMillis());

                                    // Ждём завершения выполнения потоков. Они остановятся, когда вся работа будет
                                    // сделана,
                                    // либо если flushTimer сказал им остановиться, и они доделали текущие чанки.
                                    parallelRunner.run();

                                    // Здесь в целом не нужна строгая консистентность, потому что если вся работа
                                    // выполнена,
                                    // но flushTimer не сработал, то ничего страшного не произойдёт:
                                    // он всего лишь позовёт pause() у всех загрузчиков. А во всех остальных случаях
                                    // таймер сработает раньше, чем работа завершится сама собой.
                                    flushTimer.cancel();

                                    // Сохраняем в YT стейт текущей rangesQueue
                                    saveRangesQueueState(tx.getId());
                                } finally {
                                    if (parallelRunner != null) {
                                        runners.remove(parallelRunner);
                                    }
                                    counter.stop();
                                }
                            } catch (InterruptedException e) {
                                Thread.currentThread().interrupt();
                                Checked.throwWrapped(e);
                            }
                            return null;
                        });
            }
        } finally {
            for (ShardedSqlLoader shardedSqlLoader : shardedSqlLoaders) {
                shardedSqlLoader.releaseResources();
            }
        }

        logger.info("Finished import table {} to {}", template.getTableName(), tablePath);
    }

    private void createDataTableIfNotExists(GUID transactionId) {
        if (!yt.cypress().exists(tablePath)) {
            Map<String, YTreeNode> attributes = Cf.hashMap();
            attributes.putAll(commonTableAttributes);
            if (!attributes.containsKey(YtUtils.SCHEMA_ATTR)) {
                attributes.put(YtUtils.SCHEMA_ATTR, template.getStaticSchema().toYTree());
            }
            yt.cypress().create(
                    Optional.of(transactionId),
                    true,
                    tablePath,
                    CypressNodeType.TABLE,
                    true,
                    true,
                    attributes);
        }
    }

    /**
     * Если есть таблица с состоянием диапазонов, то загружаем состояние из неё.
     * В противном случае загружаем диапазоны из БД.
     */
    private void initRangesQueue(List<String> dbNames) throws InterruptedException {
        // Если нет основной таблицы -- грузим с нуля
        if (!yt.cypress().exists(tablePath)) {
            logger.info("Table {} not found, loading ranges from database", tablePath);
            addShards(dbNames);
            return;
        }
        // Если нет таблицы с состоянием -- тоже
        YPath stateTablePath = getStateTablePath();
        if (!yt.cypress().exists(stateTablePath)) {
            logger.info("Table {} not found, loading ranges from database", stateTablePath);
            addShards(dbNames);
            return;
        }
        Map<String, List<IdRange>> rangesQueueState = new ConcurrentHashMap<>();
        try {
            yt.tables().read(stateTablePath, RangesStateRecord.YT_TYPE, record -> {
                List<IdRange> ranges =
                        fromJson(record.getRangesRemaining(), new TypeReference<List<IdRange>>() {
                        });
                rangesQueueState.put(record.getDbName(), ranges);
            });
        } catch (Exception e) {
            logger.error(String.format(
                    "Unexpected exception when restoring ranges state from %s, loading ranges from database",
                    stateTablePath), e);
            addShards(dbNames);
            return;
        }

        logger.info("Loaded ranges from {}, ranges count by shards: {}",
                stateTablePath, EntryStream.of(rangesQueueState).mapValues(List::size).toMap());
        if (intersection(rangesQueueState.keySet(), new HashSet<>(dbNames)).size() != rangesQueueState.keySet().size()) {
            logger.error(
                    "Error restoring ranges state from {}: loaded shards {} but required shards are {}, ignoring state",
                    stateTablePath, rangesQueueState.keySet(), dbNames);
            addShards(dbNames);
            return;
        }
        rangesQueue.loadFromMap(rangesQueueState);
    }

    private void saveRangesQueueState(GUID transaction) {
        logger.info("Saving current import state. Chunks remaining: {}", rangesQueue.getChunksCount());
        //
        Map<String, List<IdRange>> rangesQueueState = rangesQueue.copyToMap();
        List<RangesStateRecord> rangesStateRecords = EntryStream.of(rangesQueueState)
                .mapValues(this::toJson)
                // fix possible 'Sort order violation'
                .sorted(Map.Entry.comparingByKey())
                .mapKeyValue(RangesStateRecord::new)
                .toList();
        //
        YPath tablePath = getStateTablePath();
        Map<String, YTreeNode> attributes = Cf.hashMap();
        attributes.putAll(commonTableAttributes);
        if (!attributes.containsKey(YtUtils.SCHEMA_ATTR)) {
            attributes.put(YtUtils.SCHEMA_ATTR, createStateTableSchema());
        }
        if (yt.cypress().exists(tablePath)) {
            yt.cypress().remove(tablePath);
        }
        yt.cypress().create(
                Optional.of(transaction),
                true,
                tablePath,
                CypressNodeType.TABLE,
                true,
                false,
                attributes);
        yt.tables().write(
                Optional.of(transaction),
                true,
                tablePath,
                RangesStateRecord.YT_TYPE,
                rangesStateRecords.iterator()
        );
    }

    private String toJson(Object object) {
        try {
            return objectMapper.writeValueAsString(object);
        } catch (JsonProcessingException e) {
            throw new UncheckedIOException(e);
        }
    }

    private <T> T fromJson(String string, TypeReference typeReference) {
        try {
            return objectMapper.readValue(string, typeReference);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private static YTreeNode createStateTableSchema() {
        YTreeBuilder schemaBuilder = YTree.listBuilder();
        schemaBuilder.beginMap()
                .key("name").value(DB_NAME_YTREE_FIELD_NAME)
                .key("type").value("string")
                .key("sort_order").value("ascending")
                .key("required").value(false)
                .endMap();
        schemaBuilder.beginMap()
                .key("name").value(RANGES_REMAINING_YTREE_FIELD_NAME)
                .key("type").value("string")
                .key("required").value(false)
                .endMap();
        return YTree.builder()
                .beginAttributes()
                .key("strict").value(true)
                .key("unique_keys").value(true)
                .endAttributes()
                .value(schemaBuilder.buildList())
                .build();
    }

    private void runWorker(ShardedSqlLoader shardedSqlLoader,
                           YPath tablePath, GUID transactionId, ProcessingSpeedCounter counter) {
        try {
            YTableEntryType<FlatRow> entryType =
                    new KosherFlatRowSerializer<FlatRow>(template.getWriteSchema()).toEntryType();
            Iterator<FlatRow> rows =
                    counter.spyWithInterrupts(SqlLoaderToFlatRowIterator.fromStream(shardedSqlLoader));
            yt.tables().write(
                    Optional.of(transactionId),
                    true,
                    tablePath.append(true),
                    entryType,
                    rows,
                    compressorFactory.createCompressor());
        } catch (InterruptedRuntimeException e) {
            Checked.throwWrapped(e);
        } catch (Exception e) {
            logger.error("Import worker exception", e);
            Checked.throwWrapped(e);
        }
    }

    public void stop() {
        logger.info("Stopping sharded table exporter {}", this);
        closed.set(true);
        runners.forEach(ParallelRunner::stop);
    }

    /**
     * Удаляет таблицу с промежуточным состоянием импорта (если она есть).
     */
    public void deleteStateTable() {
        YPath stateTablePath = getStateTablePath();
        if (yt.cypress().exists(stateTablePath)) {
            logger.info("Deleting table {}", stateTablePath);
            yt.cypress().remove(stateTablePath);
            logger.info("Table {} deleted", stateTablePath);
        }
    }
}
