package ru.yandex.direct.mysql.ytsync.export.util;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.utils.InterruptedRuntimeException;

import static java.util.stream.Collectors.toMap;
import static ru.yandex.direct.utils.FunctionalUtils.mapList;

/**
 * Splits a table into id ranges of approximately {@code chunkSize} rows each.
 * <p>
 * For numeric id columns the ranges are found by recursively halving the [MIN, MAX] id interval and
 * asking the database (via {@code EXPLAIN}) for an estimated row count of every half. For string-like
 * id columns the chunk boundaries are read directly from the table with a row-numbering query.
 */
class TableIdRangeEstimator {
    private static final Logger logger = LoggerFactory.getLogger(TableIdRangeEstimator.class);

    private final Connection conn;
    private final String tableName;
    private final String idColumn;
    private final boolean idColumnHasStringLikeType;

    /**
     * @param conn                      connection used for all queries; it is not closed by this class
     * @param tableName                 table name, concatenated into SQL as is
     * @param idColumn                  id column name, concatenated into SQL as is
     * @param idColumnHasStringLikeType selects the string-based splitting strategy instead of the numeric one
     */
    TableIdRangeEstimator(Connection conn, String tableName, String idColumn,
                          boolean idColumnHasStringLikeType) {
        this.conn = conn;
        this.tableName = tableName;
        this.idColumn = idColumn;
        this.idColumnHasStringLikeType = idColumnHasStringLikeType;
    }

    /**
     * Estimates the number of rows of {@code tableName} whose {@code idColumn} lies in [minId, maxId].
     * The value is taken from the {@code rows} column of the first row of the {@code EXPLAIN} output,
     * so it is an estimate, not an exact count.
     *
     * @throws IllegalStateException if {@code EXPLAIN} returns no rows
     */
    private long estimateRowsCount(long minId, long maxId) throws SQLException {
        String explainSql =
                "EXPLAIN SELECT 1 FROM " + tableName + " WHERE " + idColumn + " BETWEEN " + minId + " AND " + maxId;
        logger.debug("Estimating table {} column {} range [{},{}]", tableName, idColumn, minId, maxId);
        try (PreparedStatement stmt = conn.prepareStatement(explainSql);
             ResultSet rs = stmt.executeQuery()) {
            if (!rs.next()) {
                throw new IllegalStateException("Unexpected empty ResultSet");
            }
            // A NULL "rows" value is read as 0 by getLong
            long count = rs.getLong("rows");
            logger.debug("Estimating table {} column {} range [{},{}]: {} rows", tableName, idColumn, minId, maxId,
                    count);
            return count;
        }
    }

    /**
     * Builds a range [minId, maxId] together with its estimated row count.
     *
     * @throws IllegalArgumentException if {@code maxId < minId}
     */
    private TableIdRange createRange(long minId, long maxId) throws SQLException {
        if (maxId < minId) {
            throw new IllegalArgumentException("Cannot create inverse range [" + minId + "," + maxId + "]");
        }
        return new TableIdRange(minId, maxId, estimateRowsCount(minId, maxId));
    }

    /**
     * Returns the first id of the upper half when [minId, maxId] is split in two, that is
     * {@code minId + ceil((maxId - minId) / 2)}.
     * <p>
     * Requires {@code minId < maxId}. Unlike the naive {@code minId + (maxId - minId + 1) / 2} this does
     * not overflow when the span does not fit into a signed long (for example ids 0 and
     * {@code Long.MAX_VALUE}, or a negative min id with a large positive max id).
     */
    private static long splitPoint(long minId, long maxId) {
        // Since maxId > minId, the wrapped difference is the exact span when read as an unsigned 64-bit value
        long span = maxId - minId;
        // Unsigned halving, rounded up; the final addition may wrap internally but the result always
        // lies in (minId, maxId] and therefore is a valid long
        long half = (span >>> 1) + (span & 1);
        return minId + half;
    }

    /**
     * Splits table {@code tableName} into ranges over {@code idColumn} of approximately {@code chunkSize} rows.
     *
     * @param chunkSize desired number of rows per range
     * @return the ranges; empty if the table has no rows
     */
    List<IdRange> estimateTableRanges(long chunkSize) throws SQLException {
        if (idColumnHasStringLikeType) {
            return getStringIdRanges(chunkSize);
        } else {
            return getLongIdRanges(chunkSize);
        }
    }

    /**
     * Splits a table with a numeric id column: starts from the whole [MIN(id), MAX(id)] interval and keeps
     * halving every range whose estimated row count exceeds {@code chunkSize} until it either fits or
     * degenerates into a single id.
     *
     * @param chunkSize desired number of rows per range
     * @return ranges ordered by estimated row count, largest first; empty if the table has no rows
     * @throws InterruptedRuntimeException if the current thread is interrupted while splitting
     */
    private List<IdRange> getLongIdRanges(long chunkSize) throws SQLException {
        long overallMinId, overallMaxId;
        String getMinMaxSql = "SELECT MIN(" + idColumn + "), MAX(" + idColumn + ") FROM " + tableName;
        try (PreparedStatement stmt = conn.prepareStatement(getMinMaxSql);
             ResultSet rs = stmt.executeQuery()) {
            if (!rs.next()) {
                throw new IllegalStateException("Unexpected empty ResultSet");
            }
            // MIN/MAX are NULL when the table is empty: nothing to split
            overallMinId = rs.getLong(1);
            if (rs.wasNull()) {
                return new ArrayList<>();
            }
            overallMaxId = rs.getLong(2);
            if (rs.wasNull()) {
                return new ArrayList<>();
            }
        }
        List<TableIdRange> ranges = new ArrayList<>();
        // Used as a stack of ranges that still have to be checked against chunkSize
        List<TableIdRange> pending = new ArrayList<>();
        pending.add(createRange(overallMinId, overallMaxId));
        while (!pending.isEmpty()) {
            if (Thread.interrupted()) {
                // Allow the computation to be aborted; the interrupt flag is restored for the caller
                Thread.currentThread().interrupt();
                throw new InterruptedRuntimeException();
            }
            TableIdRange current = pending.remove(pending.size() - 1);
            // A range is final when it is small enough or cannot be split any further (single id)
            if (current.getCount() <= chunkSize || current.getMinId() >= current.getMaxId()) {
                ranges.add(current);
                continue;
            }
            long mid = splitPoint(current.getMinId(), current.getMaxId());
            pending.add(createRange(current.getMinId(), mid - 1));
            pending.add(createRange(mid, current.getMaxId()));
        }
        // Sort the resulting ranges and compress them a little
        // (NOTE(review): presumably compressList merges neighbouring small ranges — see TableIdRange.compressList)
        Collections.sort(ranges);
        ranges = TableIdRange.compressList(ranges, chunkSize);
        // Order from largest to smallest, so that the smaller tails end up last
        ranges.sort(Comparator.comparing(TableIdRange::getCount).reversed());
        return mapList(ranges, r -> new IdRange(String.valueOf(r.getMinId()), String.valueOf(r.getMaxId()),
                r.getCount()));
    }

    /**
     * Splits a table with a string-like id column by numbering its rows in {@code idColumn} order and
     * selecting the first and the last row of every chunk.
     * <p>
     * Example for table banner_images_formats and chunkSize = 1_000_000 on dt:ppc:15
     * <pre>
     * SELECT rownum, image_hash FROM (
     * SELECT @row := @row + 1 AS rownum, image_hash
     * FROM (SELECT @row := 0) r, banner_images_formats ORDER BY image_hash) ranked
     * WHERE rownum % 1000000 in (0, 1)
     *
     * +---------+------------------------+
     * | rownum  | image_hash             |
     * +---------+------------------------+
     * |       1 | ----4QKEI-vOkD7O9xn6QA |
     * | 1000000 | jNtCZjao2Lw3oLuD4ns69w |
     * | 1000001 | jnTDBdq2d3eUuYjpSEYQIA |
     * | 2000000 | YRtGAgpMh4XAxIqMxfe5Nw |
     * | 2000001 | yrTGdXasF1YXhOSHbcYNdw |
     * +---------+------------------------+
     * </pre>
     * Every returned range reports {@code chunkSize} as its row count, including a last, possibly
     * incomplete, chunk.
     *
     * @param chunkSize number of rows per chunk
     * @return the ranges, in no particular order; empty if the table has no rows
     */
    private List<IdRange> getStringIdRanges(long chunkSize) throws SQLException {
        String overallMaxId;
        String getMaxSql = "SELECT MAX(" + idColumn + ") FROM " + tableName;
        try (PreparedStatement stmt = conn.prepareStatement(getMaxSql);
             ResultSet rs = stmt.executeQuery()) {
            if (!rs.next()) {
                throw new IllegalStateException("Unexpected empty ResultSet");
            }
            overallMaxId = rs.getString(1);
            // MAX is NULL when the table is empty: nothing to split
            if (rs.wasNull()) {
                return new ArrayList<>();
            }
        }
        String getRangesSql =
                "SELECT " + idColumn + " FROM ("
                        + "SELECT @row := @row + 1 AS rownum, " + idColumn + " "
                        + "FROM (SELECT @row := 0) r, " + tableName + " ORDER BY " + idColumn + ") ranked "
                        + "WHERE rownum % " + chunkSize + " in (0, 1)";
        List<IdRange> ranges = new ArrayList<>();
        try (PreparedStatement stmt = conn.prepareStatement(getRangesSql);
             ResultSet rs = stmt.executeQuery()) {
            // Rows are consumed in pairs: the first row of a chunk followed by its last row.
            // NOTE(review): this relies on the rows arriving in the order of the inner ORDER BY — confirm.
            while (rs.next()) {
                String minId = rs.getString(1);
                String maxId;
                if (rs.next()) {
                    maxId = rs.getString(1);
                } else {
                    // The last chunk is incomplete: close it with the overall maximum
                    maxId = overallMaxId;
                }
                ranges.add(new IdRange(minId, maxId, chunkSize));
            }
        }
        // In some tables (mod_edit, checksum) the first primary key column holds a single value
        // while the table may contain more than chunkSize rows,
        // so ranges with identical getMinId and getMaxId have to be merged into one
        ranges = new ArrayList<>(
                ranges.stream()
                        .collect(toMap(range -> Arrays.asList(range.getMaxId(), range.getMinId()), Function.identity(),
                                (r1, r2) -> r1))
                        .values());
        return ranges;
    }

}
