package ru.yandex.chemodan.app.djfs.core.db.pg;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.Iterator;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import javax.sql.DataSource;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.SneakyThrows;
import org.jetbrains.annotations.NotNull;
import org.postgresql.jdbc.PgConnection;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;

import ru.yandex.bolts.collection.ListF;
import ru.yandex.chemodan.app.djfs.core.util.StreamUtils;
import ru.yandex.commune.util.RetryUtils;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.spring.jdbc.ArgPreparedStatementSetter;

public class PgCursorUtils {
    private static final Logger logger = LoggerFactory.getLogger(PgCursorUtils.class);

    /**
     * Runs {@code query} through a server-side cursor and returns the mapped rows grouped via
     * {@link StreamUtils#batches}.
     *
     * @param dataSource source of the connection used for the cursor
     * @param batchSize  JDBC fetch size of the cursor; also passed to {@code StreamUtils.batches}
     * @param query      SQL with {@code ?} placeholders
     * @param rowMapper  maps each result row
     * @param args       placeholder values, bound via {@link ArgPreparedStatementSetter}
     * @return stream of row batches built on {@link #queryWithCursor(DataSource, int, String, RowMapper,
     *         PreparedStatementSetter)}
     */
    @SneakyThrows
    public static <T> Stream<ListF<T>> queryWithCursorAsBatches(DataSource dataSource, int batchSize, String query,
            RowMapper<T> rowMapper,
            Object... args)
    {
        return queryWithCursorAsBatches(dataSource, batchSize, query, rowMapper, new ArgPreparedStatementSetter(args));
    }


    /**
     * Runs {@code query} through a server-side cursor and returns the mapped rows grouped via
     * {@link StreamUtils#batches}.
     * <p>
     * NOTE(review): whether closing the returned stream releases the cursor early depends on
     * whether {@code StreamUtils.batches} propagates {@code onClose} handlers — TODO confirm.
     *
     * @param batchSize               JDBC fetch size of the cursor; also passed to {@code StreamUtils.batches}
     * @param preparedStatementSetter binds the query parameters
     */
    @SneakyThrows
    public static <T> Stream<ListF<T>> queryWithCursorAsBatches(DataSource dataSource, int batchSize, String query,
            RowMapper<T> rowMapper, PreparedStatementSetter preparedStatementSetter)
    {
        return StreamUtils.batches(
                queryWithCursor(dataSource, batchSize, query, rowMapper, preparedStatementSetter),
                batchSize
        );
    }

    /**
     * Runs {@code query} through a server-side cursor and returns the mapped rows as a lazy stream.
     * <p>
     * The connection is acquired when the stream's terminal operation starts. It is released in
     * three cases: the result set is exhausted, mapping or consuming a row throws, or the stream is
     * closed. A stream that may be abandoned before exhaustion (e.g. {@code limit},
     * {@code findFirst}) should be used in try-with-resources so the connection is returned promptly.
     *
     * @param batchSize JDBC fetch size of the cursor
     * @param pss       binds the query parameters
     */
    @NotNull
    public static <T> Stream<T> queryWithCursor(DataSource dataSource, int batchSize, String query,
            RowMapper<T> rowMapper, PreparedStatementSetter pss)
    {
        // The cursor is opened lazily by the supplier; remember it so onClose can release it.
        AtomicReference<CursorSpliterator<T>> opened = new AtomicReference<>();
        Stream<T> stream = StreamSupport.stream(() -> {
            CursorSpliterator<T> spliterator = CursorSpliterator.cons(batchSize, rowMapper, dataSource, query, pss);
            opened.set(spliterator);
            return spliterator;
        }, CursorSpliterator.CHARACTERISTICS, false);
        return stream.onClose(() -> {
            CursorSpliterator<T> spliterator = opened.get();
            if (spliterator != null) {
                spliterator.close();
            }
        });
    }

    /**
     * Runs {@code query} through a server-side cursor and passes an iterator over the mapped rows
     * to {@code block}. The connection is released when {@code block} returns or throws, so the
     * iterator must not be used after that.
     *
     * @param batchSize JDBC fetch size of the cursor
     * @param pss       binds the query parameters
     * @param block     consumes the rows
     * @return the value returned by {@code block}
     */
    public static <T, R> R queryWithCursor(DataSource dataSource, int batchSize, String query,
            RowMapper<T> rowMapper, PreparedStatementSetter pss, Function<Iterator<T>, R> block)
    {
        try (CursorSpliterator<T> spliterator = CursorSpliterator.cons(batchSize, rowMapper, dataSource, query, pss)) {
            return block.apply(StreamSupport.stream(spliterator, false).iterator());
        }
    }

    /**
     * Spliterator over the rows of an open PostgreSQL cursor. It releases the connection when the
     * rows are exhausted, when reading, mapping or consuming a row throws, or when {@link #close()}
     * is called.
     * <p>
     * The PostgreSQL driver only uses a cursor when autocommit is off and a fetch size is set. See
     * https://jdbc.postgresql.org/documentation/head/query.html
     * <p>
     * This works directly with the raw {@link Connection}. Downside: no tracing and no
     * {@link ArgPreparedStatementSetter} integration. Alternative: wrap the DataSource and use
     * JdbcTemplate3.
     * <p>
     * CAUTION!!! Works only with the {@code preparedStatementCacheQueries=0} parameter in the JDBC
     * connection (pgBouncer).
     */
    @RequiredArgsConstructor
    private static class CursorSpliterator<T> implements Spliterator<T>, AutoCloseable {
        /**
         * Rows are delivered in result-set order. DISTINCT and NONNULL are deliberately not
         * reported: a RowMapper may map different rows to equal objects or return null.
         */
        private static final int CHARACTERISTICS = Spliterator.ORDERED;
        private final RowMapper<T> rowMapper;
        private final ResultSetWithConnection resultSetWithConnection;
        private int totalProcessedRows = 0;
        private boolean closed = false;

        /** Opens the cursor (retried up to 3 times) and wraps it in a spliterator. */
        public static <T> CursorSpliterator<T> cons(int batchSize, RowMapper<T> rowMapper,
                DataSource dataSource, String query,
                PreparedStatementSetter preparedStatementSetter)
        {
            return new CursorSpliterator<>(rowMapper, RetryUtils.retry(logger, 3, () ->
                    openResultSet(dataSource, query, preparedStatementSetter, batchSize)
            ));
        }

        /** Releases the cursor resources; later calls are no-ops. */
        @Override
        public void close() {
            if (closed) {
                return;
            }
            // Mark first so a failed close is not retried against a half-released connection.
            closed = true;
            resultSetWithConnection.close();
        }

        @SneakyThrows
        @Override
        public boolean tryAdvance(Consumer<? super T> consumer) {
            if (closed) {
                return false;
            }
            try {
                ResultSet resultSet = resultSetWithConnection.getResultSet();
                if (!resultSet.next()) {
                    close();
                    return false;
                }
                T row = rowMapper.mapRow(resultSet, totalProcessedRows);

                ++totalProcessedRows;

                consumer.accept(row);
                return true;
            } catch (Throwable e) {
                // Release the connection, but never let a close failure hide the original error.
                closeSuppressing(this, e);
                throw e;
            }
        }

        @Override
        public Spliterator<T> trySplit() {
            return null;
        }

        @Override
        public long estimateSize() {
            // Size is unknown until the cursor is exhausted.
            return Long.MAX_VALUE;
        }

        @Override
        public int characteristics() {
            return CHARACTERISTICS;
        }
    }

    /**
     * Acquires a connection, turns autocommit off (required for a cursor), prepares and executes
     * {@code query} with the given fetch size.
     * <p>
     * On any failure, everything acquired so far (connection, statement) is released before the
     * exception propagates. This matters because the caller retries.
     */
    @SneakyThrows
    private static ResultSetWithConnection openResultSet(
            DataSource dataSource, String query, PreparedStatementSetter preparedStatementSetter, int batchSize
    )
    {
        Connection connection = dataSource.getConnection();
        ResultSetWithConnection result;
        try {
            boolean autoCommitBefore = connection.getAutoCommit();
            connection.unwrap(PgConnection.class).setAutoCommit(false);
            result = new ResultSetWithConnection(connection, autoCommitBefore);
        } catch (Throwable e) {
            // Autocommit was not (successfully) switched yet: just give the connection back.
            closeSuppressing(connection, e);
            throw e;
        }

        try {
            PreparedStatement preparedStatement = connection.prepareStatement(query);
            result.setStatement(preparedStatement);
            preparedStatementSetter.setValues(preparedStatement);
            preparedStatement.setFetchSize(batchSize);

            result.setResultSet(preparedStatement.executeQuery());
            return result;
        } catch (Throwable e) {
            closeSuppressing(result::close, e);
            throw e;
        }
    }

    /**
     * Closes {@code closeable} while handling {@code primary}, attaching any close failure to
     * {@code primary} as a suppressed exception instead of letting it mask the original error.
     */
    private static void closeSuppressing(AutoCloseable closeable, Throwable primary) {
        try {
            closeable.close();
        } catch (Throwable closeError) {
            if (closeError != primary) {
                primary.addSuppressed(closeError);
            }
        }
    }

    /**
     * Holds a cursor's connection, statement and result set, plus the connection's original
     * autocommit mode.
     */
    @RequiredArgsConstructor
    private static class ResultSetWithConnection {
        private final Connection connection;
        private final boolean autoCommitBefore;
        @Setter
        private PreparedStatement statement = null;
        @Setter
        @Getter
        private ResultSet resultSet = null;

        /**
         * Closes the result set and statement, restores the original autocommit mode, then closes
         * the connection. Each step is attempted even if an earlier one fails, so
         * {@code connection.close()} is always attempted.
         * <p>
         * Per the JDBC contract, switching autocommit back to {@code true} commits the open
         * transaction.
         */
        @SneakyThrows
        public void close() {
            try {
                try {
                    if (resultSet != null) {
                        resultSet.close();
                    }
                } finally {
                    if (statement != null) {
                        statement.close();
                    }
                }
            } finally {
                try {
                    connection.unwrap(PgConnection.class).setAutoCommit(autoCommitBefore);
                } finally {
                    connection.close();
                }
            }
        }
    }
}
