package ru.yandex.msearch.proxy.api.async.mail.subscriptions.update.dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.stream.Collectors;

import ru.yandex.logger.PrefixedLogger;
import ru.yandex.msearch.proxy.api.async.mail.subscriptions.update.dao.pojo.MigrationTask;
import ru.yandex.msearch.proxy.api.async.mail.subscriptions.update.dao.pojo.MigrationTaskStatus;
import ru.yandex.search.msal.pool.DBConnectionPool;

/**
 * Data-access object for the {@code migrations} table.
 *
 * <p>Each query method borrows a connection from {@link DBConnectionPool}, runs one statement and
 * closes the connection, statement and result set before returning. Result rows are fully read
 * while the connection is still held. SQL failures are logged and rethrown as
 * {@link RuntimeException}.
 */
public class MigrationsTasksPostgresDao {
    /** Tasks whose {@code timestamp} is older than this are removed by {@link #deleteExpiredTasks()}. */
    private static final long TASK_TTL_MS = TimeUnit.DAYS.toMillis(7);
    /** A RUNNING task whose {@code lastUpdate} is older than this is picked up again. */
    private static final long RUNNING_TASK_TTL_MS = 10000;
    /** FAILED tasks are retried only while their {@code retryCount} is below this limit. */
    private final static int RETRY_COUNT = 20;
    /** Base delay of the retry backoff: a FAILED task waits {@code RETRY_LAG_MS * 2^retryCount}. */
    private final static long RETRY_LAG_MS = 500;

    private final DBConnectionPool connectionPool;
    private final PrefixedLogger logger;

    public MigrationsTasksPostgresDao(DBConnectionPool connectionPool, PrefixedLogger logger) {
        this.connectionPool = connectionPool;
        this.logger = logger;
    }

    /**
     * Finds tasks that should be (re)processed, ordered by {@code timestamp}.
     *
     * <p>A task qualifies when it is CREATED; when it is RUNNING but its {@code lastUpdate} is
     * older than {@code RUNNING_TASK_TTL_MS}; or when it is FAILED, has {@code retryCount} below
     * {@code RETRY_COUNT} and its backoff of {@code RETRY_LAG_MS * 2^retryCount} has elapsed.
     *
     * @return matching tasks, possibly empty
     */
    public List<MigrationTask> findPendingTasks() {
        // One clock reading so both time-based conditions refer to the same instant.
        long now = System.currentTimeMillis();
        return findByQuery(connection -> {
            PreparedStatement statement = connection.prepareStatement(
                    "SELECT * FROM migrations"
                            + " WHERE status = ?"
                            + " OR (status = ? AND lastUpdate < ?)"
                            + " OR (status = ? AND retryCount < ? AND lastUpdate < (? - ? * 2^retryCount))"
                            + " ORDER BY timestamp;");
            statement.setString(1, MigrationTaskStatus.CREATED.value());
            statement.setString(2, MigrationTaskStatus.RUNNING.value());
            statement.setLong(3, now - RUNNING_TASK_TTL_MS);
            statement.setString(4, MigrationTaskStatus.FAILED.value());
            statement.setInt(5, RETRY_COUNT);
            statement.setLong(6, now);
            statement.setLong(7, RETRY_LAG_MS);
            return statement;
        });
    }

    /**
     * Returns every task in the table, ordered by {@code timestamp}.
     *
     * @return all tasks, possibly empty
     */
    public List<MigrationTask> findAll() {
        return findByQuery(connection -> connection.prepareStatement("SELECT * FROM migrations ORDER BY timestamp;"));
    }

    /** Deletes tasks whose {@code timestamp} is older than {@code TASK_TTL_MS}. */
    public void deleteExpiredTasks() {
        long minTimestamp = System.currentTimeMillis() - TASK_TTL_MS;
        execute(connection -> {
            PreparedStatement statement =
                    connection.prepareStatement("DELETE FROM migrations WHERE timestamp < ?;");
            statement.setLong(1, minTimestamp);
            return statement;
        });
    }

    /**
     * Sets a task's status, keeping its current retry count.
     *
     * @param task task to update, matched by (timestamp, uid, email)
     * @param newStatus status to store
     */
    public void updateStatus(MigrationTask task, MigrationTaskStatus newStatus) {
        updateStatus(task, newStatus, task.getRetryCount());
    }

    /**
     * Sets a task's status and retry count and stamps {@code lastUpdate} with the current time.
     *
     * @param task task to update, matched by (timestamp, uid, email)
     * @param newStatus status to store
     * @param retryCount retry count to store
     */
    public void updateStatus(MigrationTask task, MigrationTaskStatus newStatus, int retryCount) {
        execute(connection -> {
            PreparedStatement statement = connection.prepareStatement(
                    "UPDATE migrations " +
                            "SET status = ?, lastUpdate = ?, retryCount = ? " +
                            "WHERE timestamp = ? " +
                            "AND uid = ? " +
                            "AND email = ?;"
            );
            statement.setString(1, newStatus.value());
            statement.setLong(2, System.currentTimeMillis());
            statement.setInt(3, retryCount);
            statement.setLong(4, task.getTimestamp());
            statement.setLong(5, task.getUid());
            statement.setString(6, task.getEmail());
            return statement;
        });
    }

    /**
     * Collects per-status task counts and duration percentiles of recent finished tasks.
     *
     * @param percentiles percentiles to compute, written as two-digit integers (e.g. 50, 99)
     * @param periodMs only tasks whose {@code timestamp} lies within this many milliseconds
     *     before now contribute to the timings
     * @return the collected statistics
     */
    public Stat getStats(List<Integer> percentiles, long periodMs) {
        return new Stat(
                getTaskCounts(),
                getTimings(percentiles, periodMs)
        );
    }

    /** Counts tasks per status; statuses with no rows are reported with a count of 0. */
    private Map<MigrationTaskStatus, Integer> getTaskCounts() {
        Map<MigrationTaskStatus, Integer> result = executeQuery(
                connection -> connection.prepareStatement(
                        "SELECT status, COUNT(email) as cnt" +
                                " FROM migrations" +
                                " GROUP BY status;"
                ),
                this::readTaskCounts);
        Arrays.stream(MigrationTaskStatus.values()).forEach(status -> result.putIfAbsent(status, 0));
        return result;
    }

    /** Reads (status, cnt) rows into a mutable map. */
    private Map<MigrationTaskStatus, Integer> readTaskCounts(ResultSet resultSet) {
        Map<MigrationTaskStatus, Integer> counts = new HashMap<>();
        while (nextSafe(resultSet)) {
            try {
                counts.put(
                        MigrationTaskStatus.of(resultSet.getString("status")),
                        resultSet.getInt("cnt")
                );
            } catch (SQLException e) {
                logger.log(Level.WARNING, "Error parsing entry: ", e);
                throw new RuntimeException(e);
            }
        }
        return counts;
    }

    /**
     * Computes {@code lastupdate - timestamp} percentiles over 'finished' tasks created within
     * the last {@code periodMs} milliseconds.
     *
     * @return percentile -> duration in the table's time units; an empty map when no
     *     percentiles are requested. A percentile with no matching rows reads as 0 (SQL NULL)
     */
    private Map<Integer, Long> getTimings(List<Integer> percentiles, long periodMs) {
        // Duplicates would make Collectors.toMap throw; an empty list would yield "SELECT  FROM".
        List<Integer> distinctPercentiles = percentiles.stream().distinct().collect(Collectors.toList());
        if (distinctPercentiles.isEmpty()) {
            return Map.of();
        }
        // Locale.ROOT keeps digits ASCII regardless of the JVM default locale.
        // NOTE(review): values are rendered as "0.%02d", so 100 would become 0.100 (= 0.1) — callers
        // presumably pass 1..99; confirm.
        String columns = distinctPercentiles.stream()
                .map(perc -> String.format(
                        Locale.ROOT,
                        "(percentile_disc(0.%02d) within group (order by duration asc)) as prc%02d",
                        perc,
                        perc))
                .collect(Collectors.joining(","));
        long minTimestamp = System.currentTimeMillis() - periodMs;
        return executeQuery(
                connection -> {
                    PreparedStatement statement = connection.prepareStatement(
                            "SELECT " + columns +
                                    " FROM " +
                                    " (SELECT (lastupdate - timestamp) as duration " +
                                    " FROM migrations " +
                                    " WHERE status = 'finished' " +
                                    " AND timestamp > ?) as durations;");
                    statement.setLong(1, minTimestamp);
                    return statement;
                },
                resultSet -> readTimings(resultSet, distinctPercentiles));
    }

    /** Reads the single aggregate row produced by {@link #getTimings}. */
    private Map<Integer, Long> readTimings(ResultSet resultSet, List<Integer> percentiles) {
        if (!nextSafe(resultSet)) {
            return Map.of();
        }
        return percentiles.stream().collect(Collectors.toMap(
                perc -> perc,
                perc -> getLongSafe(resultSet, String.format(Locale.ROOT, "prc%02d", perc))
        ));
    }

    private long getLongSafe(ResultSet resultSet, String key) {
        try {
            return resultSet.getLong(key);
        } catch (SQLException e) {
            logger.log(Level.WARNING, "unable to get long: ", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Inserts all entries with one multi-row INSERT statement; does nothing for an empty list.
     *
     * <p>NOTE(review): each row binds 10 parameters, and PostgreSQL limits the number of bind
     * parameters per statement, so very large batches may need chunking — confirm batch sizes.
     *
     * @param entries tasks to insert
     */
    public void insert(List<MigrationTask> entries) {
        if (entries.isEmpty()) {
            return;
        }
        execute(connection -> {
            StringBuilder sb = new StringBuilder();
            sb.append("INSERT INTO migrations " +
                    "(timestamp, uid, email, action, status, requestId, optIn, types, lastUpdate, retryCount) " +
                    "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
            for (int i = 1; i < entries.size(); i++) {
                sb.append(",(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
            }
            sb.append(";");
            PreparedStatement statement = connection.prepareStatement(sb.toString());
            for (int i = 0; i < entries.size(); i++) {
                MigrationTask entry = entries.get(i);
                // Parameters for row i occupy indices i*10+1 .. i*10+10 (JDBC indices are 1-based).
                statement.setLong(i * 10 + 1, entry.getTimestamp());
                statement.setLong(i * 10 + 2, entry.getUid());
                statement.setString(i * 10 + 3, entry.getEmail());
                statement.setString(i * 10 + 4, entry.getAction().value());
                statement.setString(i * 10 + 5, entry.getStatus().value());
                statement.setString(i * 10 + 6, entry.getRequestId());
                statement.setBoolean(i * 10 + 7, entry.isOptIn());
                statement.setString(i * 10 + 8, String.join(",", entry.getTypes()));
                statement.setLong(i * 10 + 9, entry.getLastUpdate());
                statement.setInt(i * 10 + 10, entry.getRetryCount());
            }
            return statement;
        });
    }

    /** Builds a task from the current row of {@code resultSet}. */
    private MigrationTask getEntryFromResultSet(ResultSet resultSet) {
        try {
            return new MigrationTask(
                    resultSet.getLong("timestamp"),
                    resultSet.getLong("uid"),
                    resultSet.getString("email"),
                    resultSet.getString("action"),
                    resultSet.getString("status"),
                    resultSet.getString("requestId"),
                    resultSet.getBoolean("optIn"),
                    resultSet.getString("types"),
                    resultSet.getLong("lastUpdate"),
                    resultSet.getInt("retryCount")
            );
        } catch (SQLException e) {
            logger.log(Level.WARNING, "Error parsing entry: ", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs a SELECT over the {@code migrations} table and maps every row to a task.
     *
     * @param statementSupplier builds the statement to run on a borrowed connection
     * @return mapped tasks in result-set order, possibly empty
     */
    public List<MigrationTask> findByQuery(StatementSupplier statementSupplier) {
        return executeQuery(statementSupplier, this::readTasks);
    }

    /** Reads all remaining rows of {@code resultSet} as tasks. */
    private List<MigrationTask> readTasks(ResultSet resultSet) {
        List<MigrationTask> result = new ArrayList<>();
        while (nextSafe(resultSet)) {
            result.add(getEntryFromResultSet(resultSet));
        }
        return result;
    }

    private boolean nextSafe(ResultSet resultSet) {
        try {
            return resultSet.next();
        } catch (SQLException e) {
            logger.log(Level.WARNING, "Can't get next item: ", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Executes a query and maps its result while the connection is still held.
     *
     * <p>The connection, statement and result set are all closed before returning, so the
     * mapper must fully consume whatever it needs from the result set.
     */
    private <T> T executeQuery(StatementSupplier statementSupplier, ResultSetMapper<T> mapper) {
        try (Connection connection = connectionPool.getConnection(logger);
             PreparedStatement statement = statementSupplier.buildStatement(connection);
             ResultSet resultSet = statement.executeQuery()) {
            return mapper.map(resultSet);
        } catch (SQLException e) {
            logger.log(Level.WARNING, "Can't execute query: ", e);
            throw new RuntimeException(e);
        }
    }

    /** Executes a non-query statement, closing the statement and connection afterwards. */
    private void execute(StatementSupplier statementSupplier) {
        try (Connection connection = connectionPool.getConnection(logger);
             PreparedStatement statement = statementSupplier.buildStatement(connection)) {
            statement.execute();
        } catch (SQLException e) {
            logger.log(Level.WARNING, "Can't execute: ", e);
            throw new RuntimeException(e);
        }
    }

    /** Prepares (and binds) a statement on the given connection. */
    @FunctionalInterface
    private interface StatementSupplier  {
        PreparedStatement buildStatement(Connection connection) throws SQLException;
    }

    /** Converts an open result set into a value; invoked before the result set is closed. */
    @FunctionalInterface
    private interface ResultSetMapper<T> {
        T map(ResultSet resultSet) throws SQLException;
    }

    /** Snapshot of task counts per status and duration percentiles. */
    public static final class Stat {
        private final Map<MigrationTaskStatus, Integer> taskCounts;
        private final Map<Integer, Long> timings;

        /**
         * @param taskCounts number of tasks per status
         * @param timings percentile -> task duration
         */
        public Stat(Map<MigrationTaskStatus, Integer> taskCounts, Map<Integer, Long> timings) {
            this.taskCounts = taskCounts;
            this.timings = timings;
        }

        public Map<MigrationTaskStatus, Integer> getTaskCounts() {
            return taskCounts;
        }

        public Map<Integer, Long> getTimings() {
            return timings;
        }
    }
}
