package ru.yandex.direct.binlogbroker.ytbootstrap.components;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.time.ZonedDateTime;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;

import ru.yandex.direct.binlog.model.BinlogEvent;
import ru.yandex.direct.binlog.model.ColumnType;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapper;
import ru.yandex.direct.mysql.schema.ColumnSchema;
import ru.yandex.direct.mysql.schema.KeyColumn;
import ru.yandex.direct.mysql.schema.TableSchema;
import ru.yandex.direct.utils.MySQLQuote;

import static ru.yandex.direct.binlog.model.ColumnType.TIMESTAMP_ZONE_ID;

/**
 * Reads a MySQL table page by page in primary-key order using keyset ("seek") pagination.
 *
 * <p>Every page is fetched with {@code order by <primary key> limit 0, pageSize}. Every page after the first
 * is restricted to rows whose primary key is strictly greater than the primary key of the last row of the
 * previous page. Each row is returned as a {@link BinlogEvent.Row}:
 * <ul>
 *     <li>primary-key columns go into the primary key image;</li>
 *     <li>all other columns go into the "after" image;</li>
 *     <li>the "before" image is empty.</li>
 * </ul>
 */
@ParametersAreNonnullByDefault
class TablePager implements Iterable<List<BinlogEvent.Row>> {

    private static final Logger logger = LoggerFactory.getLogger(TablePager.class);

    /** Empty map used as the "before" image of every produced row. */
    private static final Map<String, Object> EMPTY = Map.of();

    /** All table columns, in the order they are selected. */
    private final List<String> columns;
    /** Primary key columns, in key order; defines the read order. */
    private final List<String> primaryKey;
    private final String firstPageSql;
    private final String subsequentPagesSql;

    private final DatabaseWrapper databaseWrapper;
    /** Snapshot of the primary-key values to start from; empty means "from the beginning of the table". */
    private final Collection<Map.Entry<String, Object>> readStartPosition;

    /**
     * Prepares the page queries for a table. No query is executed until {@link #iterator()} is called.
     *
     * @param tableName         name of the table to read
     * @param tableSchema       schema of the table; all of its columns are selected and its primary key
     *                          defines the read order
     * @param pageSize          maximum number of rows per page; must be positive
     * @param databaseWrapper   database used to run the page queries
     * @param readStartPosition primary-key values (column name to value) to start reading from, or empty to read
     *                          the whole table; the generated condition references every primary key column,
     *                          so a non-empty position is expected to provide all of them
     * @param inclusive         whether the first page may include the row whose primary key equals
     *                          {@code readStartPosition}; ignored when {@code readStartPosition} is empty
     * @throws IllegalArgumentException if {@code pageSize} is not positive
     * @throws IllegalStateException    if the table has no primary key or its primary key has no columns
     */
    TablePager(String tableName, TableSchema tableSchema, int pageSize,
               DatabaseWrapper databaseWrapper,
               Collection<Map.Entry<String, Object>> readStartPosition, boolean inclusive
    ) {
        if (pageSize <= 0) {
            // "limit 0, 0" would silently return nothing and a negative limit is a SQL syntax error.
            throw new IllegalArgumentException("pageSize must be positive, got " + pageSize);
        }
        this.databaseWrapper = databaseWrapper;
        // Snapshot the start position: the first-page SQL shape is chosen below from whether it is empty,
        // so later mutation of the caller's collection must not desynchronize the SQL and its parameters.
        this.readStartPosition = List.copyOf(readStartPosition);

        columns = tableSchema.getColumns().stream()
                .map(ColumnSchema::getName)
                .collect(Collectors.toList());
        primaryKey = tableSchema.getPrimaryKey()
                .orElseThrow(() -> new IllegalStateException("Table " + tableName + " does not have primary key"))
                .getColumns()
                .stream()
                .map(KeyColumn::getName)
                .collect(Collectors.toList());
        if (primaryKey.isEmpty()) {
            // Without this check whereKeyGreaterExpression fails with an opaque IndexOutOfBoundsException.
            throw new IllegalStateException("Table " + tableName + " has primary key without columns");
        }

        final String selectColumnsFromTable = "select " + commaSeparatedQuotedNames(columns) +
                " from " + MySQLQuote.quoteName(tableName);

        final String orderByAndLimit = " order by " + commaSeparatedQuotedNames(primaryKey) + " limit 0, " + pageSize;

        final String firstPageWherePart = this.readStartPosition.isEmpty() ? "" :
                whereKeyGreaterExpression(primaryKey, inclusive);

        firstPageSql = selectColumnsFromTable + firstPageWherePart + orderByAndLimit;
        subsequentPagesSql = selectColumnsFromTable + whereKeyGreaterExpression(primaryKey, false) + orderByAndLimit;
        logger.debug("Table {}: first page SQL: {}; subsequent pages SQL: {}",
                tableName, firstPageSql, subsequentPagesSql);
    }

    /**
     * Quotes each name with {@link MySQLQuote#quoteName} and joins the results with commas.
     */
    private static String commaSeparatedQuotedNames(List<String> rawNames) {
        return rawNames.stream().map(MySQLQuote::quoteName).collect(Collectors.joining(","));
    }

    /**
     * Builds a {@code where} clause meaning "the key tuple is greater than the bound key tuple".
     *
     * <p>Each key column {@code k} is compared with the named parameter {@code :{k}}. The parameter name is the
     * raw column name, so the bound parameter source must use column names as keys.
     *
     * @param columns    key columns in key order; must not be empty
     * @param allowEqual if {@code true}, the comparison is "greater or equal" (the bound key itself matches)
     * @return the clause, starting with {@code " where "}
     */
    private static String whereKeyGreaterExpression(List<String> columns, boolean allowEqual) {
        StringBuilder whereBuilder = new StringBuilder(" where ");
        // Builds an expression of the form k1>v1 or (k1=v1 and (k2>v2 or (k2=v2 and k3>v3))).
        // The simpler row-value comparison (k1,k2,k3) > (v1,v2,v3) is not used
        // because MySQL cannot use it properly together with an index.
        // https://use-the-index-luke.com/sql/partial-results/fetch-next-page
        int lastIndex = columns.size() - 1;
        for (int i = 0; i < lastIndex; i++) {
            String column = columns.get(i);
            String quotedColumnName = MySQLQuote.quoteName(column);
            whereBuilder.append(quotedColumnName).append('>').append(":{").append(column).append('}');
            whereBuilder.append(" or (").append(quotedColumnName).append('=').append(":{").append(column)
                    .append("} and (");
        }
        String lastColumn = columns.get(lastIndex);
        // Only the innermost (last-column) comparison differs between ">" and ">=".
        whereBuilder.append(MySQLQuote.quoteName(lastColumn))
                .append((allowEqual ? ">=" : ">"))
                .append(":{")
                .append(lastColumn)
                .append('}');

        // Each non-last column opened two parentheses: "or (" and "and (".
        for (int i = 1; i <= lastIndex; i++) {
            whereBuilder.append("))");
        }
        return whereBuilder.toString();
    }

    /**
     * Returns a new iterator over the pages of the table, starting from {@code readStartPosition}.
     *
     * <p>The first page is queried immediately, while the iterator is being created.
     */
    @Override
    public Iterator<List<BinlogEvent.Row>> iterator() {
        return new PageIterator(databaseWrapper, columns, primaryKey, firstPageSql, subsequentPagesSql,
                readStartPosition);
    }

    /**
     * Iterator that always holds the next page prefetched.
     *
     * <p>{@link #hasNext()} only checks the prefetched page. {@link #next()} returns that page and immediately
     * queries the following one, so iteration ends after the first empty page.
     */
    private static class PageIterator implements Iterator<List<BinlogEvent.Row>> {

        private final DatabaseWrapper databaseWrapper;
        /** Selected columns, in result-set order (JDBC column index = position + 1). */
        private final List<String> columns;
        /** Primary key column names, as a set for O(1) membership checks in {@link #mapRow}. */
        private final Set<String> primaryKeyColumns;
        /** SQL used for every page after the first. */
        private final String sql;
        /** Index of the next row within the page being mapped; restarts at 0 for every page. */
        private int rowIndex = 0;

        /** Prefetched page; empty when the table is exhausted. */
        private List<BinlogEvent.Row> next;

        PageIterator(DatabaseWrapper databaseWrapper, List<String> columns,
                     List<String> primaryKey, String firstPageSql, String subsequentPageSql,
                     Collection<Map.Entry<String, Object>> initialArgs) {
            this.databaseWrapper = databaseWrapper;
            this.columns = columns;
            this.primaryKeyColumns = Set.copyOf(primaryKey);
            this.sql = subsequentPageSql;

            final MapSqlParameterSource args = new MapSqlParameterSource();
            for (Map.Entry<String, Object> kv : initialArgs) {
                args.addValue(kv.getKey(), kv.getValue());
            }
            next = databaseWrapper.query(firstPageSql, args, this::mapRow);
        }

        @Override
        public boolean hasNext() {
            return !next.isEmpty();
        }

        /**
         * Returns the prefetched page and queries the page that follows its last row.
         *
         * @throws NoSuchElementException if the prefetched page is empty
         */
        @Override
        public List<BinlogEvent.Row> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            List<BinlogEvent.Row> result = next;
            rowIndex = 0;
            next = databaseWrapper.query(sql, getArgs(next), this::mapRow);
            return result;
        }

        /**
         * Builds the query parameters for the page following {@code rows}: the primary-key values of its
         * last row, keyed by column name.
         */
        private MapSqlParameterSource getArgs(List<BinlogEvent.Row> rows) {
            assert !rows.isEmpty();
            MapSqlParameterSource args = new MapSqlParameterSource();
            final BinlogEvent.Row lastRow = rows.get(rows.size() - 1);
            for (Map.Entry<String, Object> kv : lastRow.getPrimaryKey().entrySet()) {
                final Object value = kv.getValue();
                if (value instanceof LocalDateTime) {
                    // LocalDateTime values are re-attached to TIMESTAMP_ZONE_ID before binding; presumably this
                    // mirrors how ColumnType.normalize produced them - TODO confirm.
                    args.addValue(kv.getKey(), ZonedDateTime.of((LocalDateTime) value, TIMESTAMP_ZONE_ID));
                } else {
                    args.addValue(kv.getKey(), value);
                }
            }
            return args;
        }

        /**
         * Converts the current result-set row into a {@link BinlogEvent.Row}.
         *
         * <p>Each value is passed through {@link ColumnType#normalize}. Primary-key columns go into the primary key
         * image and all other columns into the "after" image.
         */
        private BinlogEvent.Row mapRow(ResultSet resultSet, int rowNumber) throws SQLException {
            Map<String, Object> pk = new HashMap<>(primaryKeyColumns.size());
            Map<String, Object> after = new HashMap<>(columns.size() - primaryKeyColumns.size());
            // The select lists exactly `columns`, in order, so JDBC index i (1-based) matches columns[i - 1].
            int i = 1;
            for (String column : columns) {
                final ColumnType.Normalized normalized = ColumnType.normalize(resultSet.getObject(i));
                final Object value = normalized == null ? null : normalized.getObject();
                (primaryKeyColumns.contains(column) ? pk : after).put(column, value);
                i++;
            }

            return new BinlogEvent.Row()
                    .withPrimaryKey(pk)
                    .withAfter(after)
                    .withBefore(EMPTY)
                    .withRowIndex(rowIndex++)
                    .validate();
        }
    }
}
