package ru.yandex.direct.binlogbroker.ytbootstrap.components;

import java.sql.Connection;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.collect.Iterators;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.binlog.model.BinlogEvent;
import ru.yandex.direct.binlog.model.CreateOrModifyColumn;
import ru.yandex.direct.binlog.model.CreateTable;
import ru.yandex.direct.binlog.model.Operation;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapper;
import ru.yandex.direct.mysql.schema.ColumnSchema;
import ru.yandex.direct.mysql.schema.KeyColumn;
import ru.yandex.direct.mysql.schema.TableSchema;
import ru.yandex.direct.utils.Checked;
import ru.yandex.direct.utils.CommonUtils;
import ru.yandex.direct.utils.Counter;
import ru.yandex.direct.utils.Interrupts.InterruptibleConsumer;

import static ru.yandex.direct.binlogbroker.mysql.MysqlUtil.mysqlDataTypeToColumnType;

/**
 * Reads one MySQL table and turns its contents into {@link BinlogEvent}s: a single
 * {@link Operation#SCHEMA} event describing the table and a stream of {@link Operation#INSERT}
 * events carrying the table rows.
 * <p>
 * The instance keeps a running count of emitted rows in a plain field, so it is not written
 * with concurrent use of {@link #fetchData} in mind.
 */
@ParametersAreNonnullByDefault
public class TableYtBootstrap {

    private static final Logger logger = LoggerFactory.getLogger(TableYtBootstrap.class);

    private final Function<String, DatabaseWrapper> databaseWrapperProvider;
    private final String source;
    private final String tableName;
    /** Page size passed to the {@link TablePager} of the main scan. */
    private final int batchRows;
    /** Maximum number of events handed to the consumer at once during the main scan. */
    private final int batchEvents;
    /** Primary key handed to the main-scan pager as its starting point; empty by default. */
    private Collection<Map.Entry<String, Object>> fetchStart = Collections.emptyList();
    /** Primary keys that are fetched (one pager each) before the main scan; empty by default. */
    private Collection<Collection<Map.Entry<String, Object>>> fetchFirst = Collections.emptyList();
    private int fetchFirstBatchRows = 0;
    private int fetchFirstBatchEvents = 0;

    // Both suppliers are wrapped with CommonUtils.memoizeLock — presumably lazy, computed once;
    // confirm against CommonUtils.
    private final Supplier<TableSchema> tableSchemaSupplier;
    private final Supplier<BinlogEvent> eventTemplateSupplier;

    /** Row limit checked before each batch; a negative value disables the limit. */
    private final long maxFetchRows;
    /** Total number of rows placed into batches so far by {@link #getBatch}. */
    private long rowCount = 0L;

    /**
     * @param databaseWrapperProvider resolves a {@link DatabaseWrapper} by source name
     * @param source                  source name; passed to the provider and stored in every event
     * @param tableName               name of the table to read
     * @param batchRows               page size for the main scan
     * @param batchEvents             maximum number of events per consumer call in the main scan
     * @param maxFetchRows            row limit checked before each batch; negative means unlimited
     */
    public TableYtBootstrap(Function<String, DatabaseWrapper> databaseWrapperProvider, String source, String tableName,
                            int batchRows, int batchEvents, long maxFetchRows) {
        this.databaseWrapperProvider = databaseWrapperProvider;
        this.source = source;
        this.tableName = tableName;
        this.batchRows = batchRows;
        this.batchEvents = batchEvents;
        this.maxFetchRows = maxFetchRows;

        this.tableSchemaSupplier = CommonUtils.memoizeLock(Checked.supplier(this::getTableSchema));
        this.eventTemplateSupplier = CommonUtils.memoizeLock(this::getEventTemplate);
    }


    /**
     * Builds the table creation (schema) event.
     *
     * @return a validated {@link Operation#SCHEMA} event listing the columns and the primary key
     * @throws IllegalArgumentException if the table has no primary key
     */
    public BinlogEvent createTableEvent() {
        logger.info("Building schema event for table {}", tableName);
        return createTableEvent(tableSchemaSupplier.get(), eventTemplateSupplier.get());
    }

    /**
     * @return the primary key currently configured as the starting point of the main scan
     */
    public Collection<Map.Entry<String, Object>> getFetchStart() {
        return fetchStart;
    }

    /**
     * Sets the primary key after which reading of the table should start.
     */
    public void setFetchStart(Collection<Map.Entry<String, Object>> primaryKey) {
        this.fetchStart = primaryKey;
    }

    /**
     * Sets additional chunks of data that must be read before everything else.
     *
     * @param primaryKeys one primary key per chunk
     * @param batchRows   page size used for these chunks
     * @param batchEvents maximum number of events taken from each chunk
     */
    public void setFetchFirst(Collection<Collection<Map.Entry<String, Object>>> primaryKeys, int batchRows,
                              int batchEvents) {
        this.fetchFirst = primaryKeys;
        this.fetchFirstBatchEvents = batchEvents;
        this.fetchFirstBatchRows = batchRows;
    }

    /**
     * Reads data from the table — first every {@code fetchFirst} chunk, then everything in order
     * starting from {@code fetchStart} — transforms it into {@link BinlogEvent}s and passes the
     * events to {@code eventConsumer}.
     *
     * @param eventConsumer receives the batches of events
     * @throws InterruptedException if the consumer is interrupted
     */
    public void fetchData(InterruptibleConsumer<List<BinlogEvent>> eventConsumer) throws InterruptedException {
        logger.info("Start fetching data for table {}", tableName);
        final TableSchema tableSchema = tableSchemaSupplier.get();
        final DatabaseWrapper wrapper = databaseWrapperProvider.apply(source);
        final BinlogEvent eventTemplate = eventTemplateSupplier.get();

        // A single counter is shared by both phases so query indexes keep increasing across them.
        Counter queryIndex = new Counter(0);
        // fetch first
        for (Collection<Map.Entry<String, Object>> primaryKey : fetchFirst) {
            final Iterator<List<BinlogEvent.Row>> firstPager =
                    new TablePager(tableName, tableSchema, fetchFirstBatchRows, wrapper, primaryKey, true)
                            .iterator();
            // Exactly one batch is taken per chunk; anything the pager has beyond
            // fetchFirstBatchEvents pages is not read here.
            // NOTE(review): the batch may be empty (no pages, or row limit reached) and is still
            // handed to the consumer — confirm consumers tolerate an empty list.
            eventConsumer.accept(getBatch(firstPager, fetchFirstBatchEvents, eventTemplate, queryIndex));
        }
        // fetch rest
        final Iterator<List<BinlogEvent.Row>> pager =
                new TablePager(tableName, tableSchema, batchRows, wrapper, fetchStart, false)
                        .iterator();
        List<BinlogEvent> batch;
        // An empty batch means either the pager is exhausted or the row limit has been reached.
        while (!(batch = getBatch(pager, batchEvents, eventTemplate, queryIndex)).isEmpty()) {
            eventConsumer.accept(batch);
            logger.debug("Supplied queries #{}+{} to event consumer for table {}",
                    batch.get(0).getQueryIndex(), batch.size(), tableName);
        }
        logger.info("Finished fetching data for table {}", tableName);
    }

    /**
     * Dumps the schema of {@code tableName} using a connection obtained from the source database.
     * The connection is closed before returning.
     */
    private TableSchema getTableSchema() throws SQLException {
        DatabaseWrapper databaseWrapper = databaseWrapperProvider.apply(source);
        try (Connection connection = databaseWrapper.getDataSource().getConnection()) {
            return TableSchema.dump(connection, tableName);
        }
    }

    /**
     * Builds the event that every emitted event is derived from via
     * {@link BinlogEvent#fromTemplate}: server UUID, source, database, table, creation time,
     * zero transaction id and query index, and {@link Operation#INSERT}.
     */
    private BinlogEvent getEventTemplate() {
        final DatabaseWrapper databaseWrapper = databaseWrapperProvider.apply(source);
        return new BinlogEvent()
                .withServerUuid(getServerUuid(databaseWrapper))
                .withTransactionId(0L)
                .withSource(source)
                // The field holds a UTC timestamp, so read the clock in UTC explicitly;
                // a bare LocalDateTime.now() would use the JVM default time zone.
                .withUtcTimestamp(LocalDateTime.now(ZoneOffset.UTC))
                .withQueryIndex(0)
                .withDb(databaseWrapper.getDbname())
                .withTable(tableName)
                .withOperation(Operation.INSERT);
    }

    /**
     * Converts a table schema into a validated {@link Operation#SCHEMA} event based on
     * {@code template}.
     *
     * @throws IllegalArgumentException if {@code schema} has no primary key
     */
    private BinlogEvent createTableEvent(TableSchema schema, BinlogEvent template) {
        if (!schema.getPrimaryKey().isPresent()) {
            throw new IllegalArgumentException(
                    "Input table " + tableName + " does not have primary key");
        }

        final List<CreateOrModifyColumn> columns = new ArrayList<>();
        for (ColumnSchema columnSchema : schema.getColumns()) {
            columns.add(
                    new CreateOrModifyColumn()
                            .withColumnName(columnSchema.getName())
                            .withColumnType(mysqlDataTypeToColumnType(columnSchema))
                            .withNullable(columnSchema.isNullable())
                            .withDefaultValue(columnSchema.getDefaultValue())
            );
        }

        final List<String> primaryKey = new ArrayList<>();
        for (KeyColumn keyColumn : schema.getPrimaryKey().get().getColumns()) {
            primaryKey.add(keyColumn.getName());
        }

        return BinlogEvent.fromTemplate(template)
                .withOperation(Operation.SCHEMA)
                .withAddedSchemaChanges(
                        new CreateTable()
                                .withColumns(columns)
                                .withPrimaryKey(primaryKey)
                )
                .validate();
    }

    /**
     * Queries {@code @@server_uuid} from the given database and returns the first result row.
     */
    private String getServerUuid(DatabaseWrapper databaseWrapper) {
        return databaseWrapper.query("select @@server_uuid", (rs, i) -> rs.getString(1)).get(0);
    }

    /**
     * Takes up to {@code batchEvents} pages from {@code source} and wraps each page into one
     * validated event derived from {@code template}, numbered with the next value of
     * {@code queryIndex}. Adds the rows of the produced events to {@link #rowCount}.
     * <p>
     * The row limit is checked only before the batch is built, so the total number of rows
     * may exceed {@code maxFetchRows} by up to one batch.
     *
     * @return the produced events; empty if the row limit has already been reached or
     * {@code source} has no more pages
     */
    private List<BinlogEvent> getBatch(Iterator<List<BinlogEvent.Row>> source, int batchEvents, BinlogEvent template,
                                       Counter queryIndex) {
        final List<BinlogEvent> batch = new ArrayList<>(batchEvents);
        if (maxFetchRows < 0L || rowCount < maxFetchRows) {
            final Iterator<List<BinlogEvent.Row>> limited = Iterators.limit(source, batchEvents);
            while (limited.hasNext()) {
                batch.add(BinlogEvent.fromTemplate(template)
                        .withQueryIndex(queryIndex.next())
                        .withRows(limited.next())
                        .validate());
            }
            // Counted only after the whole batch has been built and validated.
            for (BinlogEvent event : batch) {
                rowCount += event.getRows().size();
            }
        }
        return batch;
    }
}
