package ru.yandex.direct.mysql.ytsync.synchronizator.tableprocessors;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.mysql.BinlogEventData;
import ru.yandex.direct.mysql.MySQLSimpleRow;
import ru.yandex.direct.mysql.MySQLUpdateRows;
import ru.yandex.direct.mysql.ytsync.common.row.FlatRow;
import ru.yandex.direct.mysql.ytsync.synchronizator.tables.Table;
import ru.yandex.direct.mysql.ytsync.synchronizator.tables.TableWriteSink;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yt.ytclient.tables.ColumnSchema;
import ru.yandex.yt.ytclient.tables.TableSchema;

/**
 * {@link TableProcessor} that maps binlog row events of a fixed set of MySQL tables onto a single {@link Table} sink.
 *
 * <ul>
 *     <li>Only events whose database is in {@code databases} and whose table is in {@code allTables} are handled;
 *     everything else is ignored.</li>
 *     <li>INSERTs into a main table become sink inserts, and that table is remembered as the "owner" of the row's
 *     index key until {@link #flush()}.</li>
 *     <li>INSERTs into any other tracked table become sink updates.</li>
 *     <li>DELETEs are handled only for main tables, and only when the row has no known owner or is owned by the
 *     table the delete comes from.</li>
 *     <li>UPDATEs from any tracked table become sink updates (before/after pair).</li>
 * </ul>
 *
 * <p>NOTE(review): {@code rowOwners} is an unsynchronized {@link HashMap}, so the handler methods and
 * {@link #flush()} presumably run on a single thread — confirm with callers.
 */
@ParametersAreNonnullByDefault
public class TaskBasedTableProcessor implements TableProcessor {
    private static final Logger logger = LoggerFactory.getLogger(TaskBasedTableProcessor.class);

    /** System property that overrides the parallel-conversion threshold. */
    private static final String ROWS_NUMBER_BORDER_PROPERTY = "rows_number_border_for_parallel";
    private static final int DEFAULT_ROWS_NUMBER_BORDER_FOR_PARALLEL = 1000;

    /**
     * Update events with at least this many rows are converted to {@link FlatRow}s via a parallel stream;
     * smaller events are converted sequentially.
     */
    private static final int ROWS_NUMBER_BORDER_FOR_PARALLEL = readRowsNumberBorderForParallel();

    private final List<Integer> indexKeyColumnNums;

    private final Set<String> databases;
    private final Set<String> mainTables;
    private final Map<String, Predicate<MySQLSimpleRow>> skipInsertTables;
    private final Set<String> allTables;
    private final SyncFlatRowCreatorBase syncFlatRowCreator;
    private final Table sink;
    /** Index key -> main table whose INSERT produced the row; reset by {@link #flush()}. */
    private final Map<FlatRow, String> rowOwners = new HashMap<>();

    /**
     * Creates a processor that never skips inserts into main tables.
     *
     * @param databases          databases whose events are processed
     * @param mainTables         tables whose INSERT/DELETE events are applied as inserts/deletes
     * @param allTables          all tables whose events are processed (main and secondary)
     * @param syncFlatRowCreator converts MySQL rows into sink rows
     * @param indexSchema        schema whose sorted columns form the row index key; may be {@code null}
     * @param sink               target table
     * @throws IllegalArgumentException if an index column is missing from the sink write schema
     */
    public TaskBasedTableProcessor(
            Set<String> databases, Set<String> mainTables, Set<String> allTables,
            SyncFlatRowCreatorBase syncFlatRowCreator,
            @Nullable TableSchema indexSchema, Table sink) {
        this(databases, mainTables, Collections.emptyMap(), allTables, syncFlatRowCreator,
                indexSchema, sink);
    }

    /**
     * Creates a processor.
     *
     * @param databases          databases whose events are processed
     * @param mainTables         tables whose INSERT/DELETE events are applied as inserts/deletes
     * @param skipInsertTables   per-main-table predicates; inserted rows matching the predicate are skipped
     * @param allTables          all tables whose events are processed (main and secondary)
     * @param syncFlatRowCreator converts MySQL rows into sink rows
     * @param indexSchema        schema whose sorted columns form the row index key; may be {@code null}
     * @param sink               target table
     * @throws IllegalArgumentException if an index column is missing from the sink write schema
     */
    public TaskBasedTableProcessor(
            Set<String> databases, Set<String> mainTables, Map<String, Predicate<MySQLSimpleRow>> skipInsertTables,
            Set<String> allTables,
            SyncFlatRowCreatorBase syncFlatRowCreator,
            @Nullable TableSchema indexSchema, Table sink) {
        this.databases = databases;
        this.mainTables = mainTables;
        this.skipInsertTables = skipInsertTables;
        this.allTables = allTables;
        this.syncFlatRowCreator = syncFlatRowCreator;
        this.sink = sink;

        indexKeyColumnNums = prepareIndexes(indexSchema, sink);
    }

    /**
     * Reads the parallel-conversion threshold from the system property, falling back to the default (with a
     * warning) when the property is absent or not a valid integer, instead of failing class initialization.
     */
    private static int readRowsNumberBorderForParallel() {
        String raw = System.getProperty(ROWS_NUMBER_BORDER_PROPERTY);
        if (raw == null) {
            return DEFAULT_ROWS_NUMBER_BORDER_FOR_PARALLEL;
        }
        try {
            return Integer.parseInt(raw.trim());
        } catch (NumberFormatException e) {
            logger.warn("Invalid value '{}' of system property {}, using default {}",
                    raw, ROWS_NUMBER_BORDER_PROPERTY, DEFAULT_ROWS_NUMBER_BORDER_FOR_PARALLEL, e);
            return DEFAULT_ROWS_NUMBER_BORDER_FOR_PARALLEL;
        }
    }

    /**
     * Resolves the positions (in the sink write schema) of the sorted columns of {@code indexSchema}.
     *
     * <p>NOTE(review): with a {@code null} index schema the result is empty, so every row gets the same empty
     * index key and shares one owner entry — presumably only used with a single main table; confirm.
     *
     * @return column positions in index-schema order; empty if {@code indexSchema} is {@code null}
     * @throws IllegalArgumentException if an index column is absent from the sink write schema
     */
    private static List<Integer> prepareIndexes(@Nullable TableSchema indexSchema, TableWriteSink sink) {
        if (indexSchema == null) {
            return Collections.emptyList();
        }
        var writeSchema = sink.getWriteSchema();
        return indexSchema.getColumns().stream()
                .filter(c -> c.getSortOrder() != null)
                .map(ColumnSchema::getName)
                .map(name -> {
                    int index = writeSchema.findColumn(name);
                    if (index == -1) {
                        throw new IllegalArgumentException("Cannot find column " + name);
                    }
                    return index;
                })
                .collect(Collectors.toList());
    }

    /**
     * Extracts the index-key columns from {@code row}.
     *
     * @return the index key, or {@code null} if any key column of {@code row} is unset
     */
    @Nullable
    private FlatRow makeIndexKey(FlatRow row) {
        FlatRow indexKey = new FlatRow(indexKeyColumnNums.size());

        for (int i = 0; i < indexKeyColumnNums.size(); i++) {
            YTreeNode value = row.get(indexKeyColumnNums.get(i));
            if (value == null) {
                return null;
            }
            indexKey.set(i, value);
        }
        return indexKey;
    }

    /** Forgets all recorded row owners. */
    @Override
    public void flush() {
        rowOwners.clear();
    }

    @Override
    public Table getMainTable() {
        return sink;
    }

    @Override
    public Set<TableWriteSink> getSinks() {
        return Collections.singleton(sink);
    }

    /** Returns {@code true} if the event's database or table is not tracked by this processor. */
    private boolean doNotNeedProcessing(BinlogEventData.BinlogRowsEventData eventData) {
        return !databases.contains(eventData.getTableMap().getDatabase()) ||
                !allTables.contains(eventData.getTableMap().getTable());
    }

    /**
     * Applies DELETE events of main tables to the sink; DELETEs from non-main tables are ignored.
     *
     * @return the sum of the values returned by {@code sink.addDelete}
     */
    @Override
    public int handleDelete(BinlogEventData.Delete deleted, String dbName) {
        String table = deleted.getTableMap().getTable();
        if (doNotNeedProcessing(deleted) || !mainTables.contains(table)) {
            return 0;
        }

        int count = 0;
        for (MySQLSimpleRow row : deleted.getRows()) {
            FlatRow flatRow = syncFlatRowCreator.createFrom(dbName, table, row, false);

            FlatRow indexKey = makeIndexKey(flatRow);
            if (indexKey == null) {
                logger.info("Cannot process incomplete delete: {}", flatRow);
                continue;
            }
            String currentOwner = rowOwners.get(indexKey);
            // Apply the DELETE only if the row's owner is unknown (its INSERT was not seen since the last flush(),
            // e.g. it happened well outside the current transaction) or the delete comes from the same table.
            if (currentOwner == null || currentOwner.equals(table)) {
                rowOwners.remove(indexKey);
                count += sink.addDelete(flatRow);
            }
        }
        return count;
    }

    /**
     * Applies INSERT events: rows of main tables become sink inserts, rows of other tracked tables become updates.
     *
     * @return the sum of the values returned by the sink calls
     */
    @Override
    public int handleInsert(BinlogEventData.Insert inserted, String dbName) {
        if (doNotNeedProcessing(inserted)) {
            return 0;
        }

        String table = inserted.getTableMap().getTable();
        return mainTables.contains(table)
                ? handleRealInsert(inserted, dbName, table)
                : handleUpdateInsert(inserted, dbName, table);
    }

    /** Applies rows inserted into a secondary table as sink updates, skipping rows with an incomplete index key. */
    private int handleUpdateInsert(BinlogEventData.Insert inserted, String dbName, String table) {
        int count = 0;
        for (MySQLSimpleRow row : inserted.getRows()) {
            FlatRow flatRow = syncFlatRowCreator.createFrom(dbName, table, row, true);
            // The key itself is not used; it only checks that all index columns are present.
            if (makeIndexKey(flatRow) == null) {
                logger.info("Cannot process incomplete update: {}", flatRow);
                continue;
            }
            count += sink.addUpdate(flatRow);
        }
        return count;
    }

    /**
     * Applies rows inserted into a main table as sink inserts and records {@code table} as the owner of each row.
     * Rows matching the table's skip predicate (if any) or with an incomplete index key are skipped.
     */
    private int handleRealInsert(BinlogEventData.Insert inserted, String dbName, String table) {
        Predicate<MySQLSimpleRow> skipPredicate = skipInsertTables.get(table);
        int count = 0;
        for (MySQLSimpleRow row : inserted.getRows()) {
            if (skipPredicate != null && skipPredicate.test(row)) {
                logger.info("Skip insert for table {} by predicate", table);
                continue;
            }

            FlatRow flatRow = syncFlatRowCreator.createFrom(dbName, table, row, true);
            FlatRow indexKey = makeIndexKey(flatRow);
            if (indexKey == null) {
                logger.info("Cannot process incomplete insert: {}", flatRow);
                continue;
            }
            rowOwners.put(indexKey, table);
            count += sink.addInsert(flatRow);
        }
        return count;
    }

    /**
     * Applies UPDATE events of any tracked table as before/after sink updates. Large events are converted to
     * {@link FlatRow}s in parallel; sink writes always happen sequentially in event order.
     *
     * @return the sum of the values returned by {@code sink.addUpdate}
     */
    @Override
    public int handleUpdate(BinlogEventData.Update updated, String dbName) {
        if (doNotNeedProcessing(updated)) {
            return 0;
        }

        var table = updated.getTableMap().getTable();
        var rows = updated.getRows();
        if (rows.size() < ROWS_NUMBER_BORDER_FOR_PARALLEL) {
            return handleUpdateSingleThread(dbName, table, rows);
        } else {
            return handleUpdateParallel(dbName, table, rows);
        }
    }

    private int handleUpdateSingleThread(String dbName, String table, MySQLUpdateRows rows) {
        int count = 0;
        for (var row : rows) {
            FlatRow before = syncFlatRowCreator.createFrom(dbName, table, row.getBeforeUpdate(), false);
            FlatRow after = syncFlatRowCreator.createFrom(dbName, table, row.getAfterUpdate(), false);
            count += sink.addUpdate(before, after);
        }
        return count;
    }

    private int handleUpdateParallel(String dbName, String table, MySQLUpdateRows rows) {
        // NOTE(review): createFrom is invoked concurrently here (by default on the common ForkJoin pool); this
        // assumes syncFlatRowCreator is safe for concurrent use — confirm.
        var flatRows = rows.parallelStream()
                .map(row -> {
                    FlatRow before = syncFlatRowCreator.createFrom(dbName, table, row.getBeforeUpdate(), false);
                    FlatRow after = syncFlatRowCreator.createFrom(dbName, table, row.getAfterUpdate(), false);
                    return new RowBeforeAfter(before, after);
                })
                .collect(Collectors.toList());

        // collect(toList()) preserves encounter order, so sink writes keep the binlog row order.
        int count = 0;
        for (var row : flatRows) {
            count += sink.addUpdate(row.before, row.after);
        }
        return count;
    }

    /** Converted before/after images of one updated row. */
    private static final class RowBeforeAfter {
        private final FlatRow before;
        private final FlatRow after;

        RowBeforeAfter(FlatRow before, FlatRow after) {
            this.before = before;
            this.after = after;
        }
    }
}
