package ru.yandex.direct.mysql.ytsync.common.components;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import one.util.streamex.StreamEx;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.function.Function;
import ru.yandex.direct.mysql.MySQLBinlogState;
import ru.yandex.direct.mysql.schema.ServerSchema;
import ru.yandex.direct.mysql.ytsync.common.compatibility.YtSupport;
import ru.yandex.direct.mysql.ytsync.common.keys.PivotKeys;
import ru.yandex.direct.mysql.ytsync.common.row.FlatRow;
import ru.yandex.direct.mysql.ytsync.common.tables.TableWriteOperations;
import ru.yandex.direct.mysql.ytsync.common.tables.TableWriteSnapshot;
import ru.yandex.direct.mysql.ytsync.common.util.YtSyncCommonUtil;
import ru.yandex.direct.ytwrapper.model.attributes.OptimizeForAttr;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTree;
import ru.yandex.inside.yt.kosher.ytree.YTreeNode;
import ru.yandex.yt.ytclient.tables.ColumnValueType;
import ru.yandex.yt.ytclient.tables.TableSchema;

/**
 * Abstracts access to the mysql-sync-states table on YT.
 * <p>
 * Two in-memory caches are kept, both keyed by database name (e.g. {@code ppc:1}):
 * {@code upstreamStates} mirrors rows as they were last read from or committed to YT, and
 * {@code updatedStates} holds local modifications that have not been committed yet.
 * {@link #prepare} turns the pending modifications into a write snapshot and {@link #committed}
 * folds them into the upstream cache.
 * <p>
 * WARNING: not thread-safe!
 */
public class SyncStatesTable implements TableWriteOperations {
    private static final Logger logger = LoggerFactory.getLogger(SyncStatesTable.class);
    private static final ObjectMapper objectMapper = new ObjectMapper();

    // Column indexes; they must follow the column order declared in SCHEMA below.
    private static final int COLUMN_DBNAME = 0;
    private static final int COLUMN_GTID_SET = 1;
    private static final int COLUMN_SERVER_SCHEMA = 2;
    private static final int COLUMN_LAST_TIMESTAMP = 3;
    private static final int COLUMN_IS_IMPORTED = 4;
    private static final int COLUMN_IMPORTED_TABLES = 5;

    public static final TableSchema SCHEMA = new TableSchema.Builder()
            .addKey("dbname", ColumnValueType.STRING)
            .addValue("gtid_set", ColumnValueType.STRING)
            .addValue("server_schema", ColumnValueType.STRING)
            .addValue("last_timestamp", ColumnValueType.INT64)
            .addValue("is_imported", ColumnValueType.BOOLEAN)
            .addValue("imported_tables", ColumnValueType.STRING)
            .build();

    private static final TableSchema LOOKUP_SCHEMA = SCHEMA.toLookup();

    public static final OptimizeForAttr OPTIMIZE_FOR = OptimizeForAttr.LOOKUP;

    public static final PivotKeys PIVOT_KEYS = PivotKeys.singlePartition();

    private final YtSupport ytSupport;
    private final String path;
    private final String medium;
    private final Consumer<Map<Integer, Long>> timeLagConsumer;
    private final Map<String, SyncState> upstreamStates = new HashMap<>();
    private final Map<String, SyncState> updatedStates = new HashMap<>();

    /**
     * In-memory representation of one row of the table. Any field may be {@code null} when the
     * corresponding column is absent. Fields are only mutated by the enclosing {@link SyncStatesTable}.
     */
    public static class SyncState {
        private final String dbName;
        private String gtidSet;
        private ServerSchema serverSchema;
        private Long lastTimestamp;
        private Boolean isImported;
        private List<String> importedTables;

        public SyncState(String dbName) {
            this.dbName = dbName;
        }

        /**
         * Returns a shallow copy of this state. {@code serverSchema} and {@code importedTables} are
         * shared by reference; the enclosing class only ever stores unmodifiable lists (or null) in
         * {@code importedTables}, so sharing that list is safe.
         */
        public SyncState copy() {
            SyncState copy = new SyncState(dbName);
            copy.gtidSet = gtidSet;
            copy.serverSchema = serverSchema;
            copy.lastTimestamp = lastTimestamp;
            copy.isImported = isImported;
            copy.importedTables = importedTables;
            return copy;
        }

        public String getDbName() {
            return dbName;
        }

        public String getGtidSet() {
            return gtidSet;
        }

        public ServerSchema getServerSchema() {
            return serverSchema;
        }

        public Long getLastTimestamp() {
            return lastTimestamp;
        }

        /**
         * @return the is_imported flag, or {@code null} if the column is absent
         */
        public Boolean getImported() {
            return isImported;
        }

        /**
         * @return unmodifiable list of imported table names, or {@code null} if none were recorded
         */
        public List<String> getImportedTables() {
            return importedTables;
        }
    }

    /**
     * Applies {@code fn} to {@code node}, treating a missing node or an entity node (which this class
     * writes for absent values) as {@code null}.
     */
    private static <U> U extract(YTreeNode node, Function<? super YTreeNode, U> fn) {
        if (node != null && !(node.isEntityNode())) {
            return fn.apply(node);
        }
        return null;
    }

    /**
     * @param ytSupport       YT access facade
     * @param path            path of the sync-states table
     * @param medium          value for the table's {@code primary_medium} attribute (used by {@link #prepareTable})
     * @param timeLagConsumer receives shard -> last timestamp after every non-empty {@link #committed()}
     */
    public SyncStatesTable(YtSupport ytSupport,
                           String path,
                           String medium,
                           Consumer<Map<Integer, Long>> timeLagConsumer) {
        this.ytSupport = ytSupport;
        this.path = path;
        this.medium = medium;
        this.timeLagConsumer = timeLagConsumer;
    }

    public String getPath() {
        return path;
    }

    /**
     * Verifies that the table exists on YT with the expected schema and mounts it if necessary.
     */
    public void verifyAndMount() {
        YtSyncCommonUtil.verifyDynamicTable(ytSupport, getPath(), SCHEMA);

        // Unfreeze is intentionally not requested here: nobody is supposed to freeze the state table,
        // so manually freezing it can potentially be used to stop replication.
        YtSyncCommonUtil.makeTableMounted(ytSupport, getPath(), false);
    }

    /**
     * Creates and prepares the dynamic table on YT (blocks until done).
     */
    public void prepareTable() {
        ytSupport.prepareDynamicTable(
                getPath(),
                SCHEMA,
                OPTIMIZE_FOR,
                PIVOT_KEYS,
                Cf.map("primary_medium", YTree.stringNode(medium))
        ).join(); // IGNORE-BAD-JOIN DIRECT-149116
    }

    /**
     * Decodes a flat table row into a {@link SyncState}.
     *
     * @return the decoded state, or empty if {@code server_schema} cannot be deserialized
     */
    private static Optional<SyncState> decodeRow(FlatRow row) {
        SyncState state = new SyncState(row.get(COLUMN_DBNAME).stringValue());
        state.gtidSet = extract(row.get(COLUMN_GTID_SET), YTreeNode::stringValue);
        try {
            String serverSchema = extract(row.get(COLUMN_SERVER_SCHEMA), YTreeNode::stringValue);
            if (serverSchema != null) {
                state.serverSchema = objectMapper.readValue(serverSchema, ServerSchema.class);
            }
        } catch (IOException e) {
            // NOTE(review): the whole row is then treated as absent, so the next update of this db
            // rewrites every column (unset ones as null). Confirm this is the intended recovery.
            logger.warn("Failed to deserialize sync state for {}", state.dbName, e);
            return Optional.empty();
        }
        state.lastTimestamp = extract(row.get(COLUMN_LAST_TIMESTAMP), YTreeNode::longValue);
        state.isImported = extract(row.get(COLUMN_IS_IMPORTED), YTreeNode::boolValue);
        String rawImportedTables = extract(row.get(COLUMN_IMPORTED_TABLES), YTreeNode::stringValue);
        if (rawImportedTables != null && !rawImportedTables.isEmpty()) {
            state.importedTables = Collections.unmodifiableList(Arrays.asList(rawImportedTables.split(",")));
        } else {
            state.importedTables = null;
        }
        return Optional.of(state);
    }

    /**
     * Reads the row for {@code dbName} directly from YT, bypassing the caches (blocks until done).
     *
     * @return the decoded row, or empty if there is no row or it cannot be decoded
     * @throws IllegalStateException if the lookup returns more than one row or a row with another key
     */
    public Optional<SyncState> lookupRow(String dbName) {
        FlatRow key = new FlatRow(LOOKUP_SCHEMA.getColumnsCount());
        key.set(COLUMN_DBNAME, YTree.builder().value(dbName).build());
        List<FlatRow> resultRows = ytSupport.nullTransaction()
                .thenComposeAsync(
                        tx -> tx.lookupRows(getPath(), LOOKUP_SCHEMA, Collections.singletonList(key), SCHEMA),
                        ytSupport.executor())
                .join(); // IGNORE-BAD-JOIN DIRECT-149116
        if (resultRows.isEmpty()) {
            return Optional.empty();
        }
        if (resultRows.size() != 1) {
            throw new IllegalStateException("LookupRows returned " + resultRows.size() + " unexpected rows");
        }
        FlatRow row = resultRows.get(0);
        String rowDbName = row.get(COLUMN_DBNAME).stringValue();
        if (!dbName.equals(rowDbName)) {
            throw new IllegalStateException("LookupRows returned unexpected key: " + rowDbName);
        }
        return decodeRow(row);
    }

    /**
     * Returns the cached row as it is stored on YT, looking it up on a cache miss.
     * Absent rows are not cached, so repeated calls for a missing db repeat the lookup.
     *
     * @return the upstream state, or {@code null} if there is none
     */
    private SyncState getUpstreamState(String dbName) {
        SyncState state = upstreamStates.get(dbName);
        if (state == null) {
            Optional<SyncState> upstreamState = lookupRow(dbName);
            if (upstreamState.isPresent()) {
                state = upstreamState.get();
                upstreamStates.put(dbName, state);
            }
        }
        return state;
    }

    /**
     * Returns the current state, taking pending updates into account.
     *
     * @param dbName    database name (e.g. ppc:1)
     * @param forUpdate true if the caller is going to modify the state; in that case a pending copy
     *                  (or a fresh empty state) is registered in {@code updatedStates} and returned,
     *                  so the result is never {@code null}
     * @return the state, or {@code null} when {@code forUpdate} is false and nothing is known
     */
    private SyncState getState(String dbName, boolean forUpdate) {
        SyncState state = updatedStates.get(dbName);
        if (state == null) {
            state = getUpstreamState(dbName);
            if (forUpdate) {
                if (state != null) {
                    state = state.copy();
                } else {
                    state = new SyncState(dbName);
                }
                updatedStates.put(dbName, state);
            }
        }
        return state;
    }

    /**
     * @return binlog state built from the stored schema and gtid set, or {@code null} if either is missing
     */
    public MySQLBinlogState getBinlogState(String dbName) {
        SyncState state = getState(dbName, false);
        if (state != null && state.serverSchema != null && state.gtidSet != null) {
            return new MySQLBinlogState(state.serverSchema, state.gtidSet);
        }
        return null;
    }

    /**
     * @return the is_imported flag; {@code false} when there is no state or the flag is absent
     */
    public boolean isImported(String dbName) {
        SyncState state = getState(dbName, false);
        return state != null && Boolean.TRUE.equals(state.isImported);
    }

    /**
     * @return whether {@code tableName} is in the imported-tables list of {@code dbName}
     */
    public boolean isTableImported(String dbName, String tableName) {
        SyncState state = getState(dbName, false);
        if (state != null && state.importedTables != null) {
            return state.importedTables.contains(tableName);
        }
        return false;
    }

    /**
     * @return the last stored timestamp, or {@code null} if unknown
     */
    public Long getLastTimestamp(String dbName) {
        SyncState state = getState(dbName, false);
        if (state != null) {
            return state.lastTimestamp;
        }
        return null;
    }

    public void setBinlogState(String dbName, MySQLBinlogState binlogState) {
        SyncState state = getState(dbName, true);
        state.gtidSet = binlogState.getGtidSet();
        state.serverSchema = binlogState.getServerSchema();
    }

    /**
     * Replaces the pending state for {@code dbName}.
     * NOTE(review): {@link #prepare} writes the row under {@code state.getDbName()}, so callers are
     * expected to pass a state whose db name equals {@code dbName}.
     */
    public void setState(String dbName, SyncState state) {
        updatedStates.put(dbName, state);
    }

    public void setImported(String dbName, boolean isImported) {
        SyncState state = getState(dbName, true);
        state.isImported = isImported;
    }

    /**
     * Adds {@code tableName} to (or removes it from) the imported-tables list of {@code dbName}.
     * The list is replaced with a new unmodifiable copy, so previously returned lists are not affected.
     */
    public void setTableImported(String dbName, String tableName, boolean isImported) {
        SyncState state = getState(dbName, true);
        List<String> current = state.importedTables != null ? state.importedTables : Collections.emptyList();
        if (current.contains(tableName) == isImported) {
            // the table is already in the requested state
            return;
        }
        List<String> newTables = new ArrayList<>(current.size() + 1);
        newTables.addAll(current);
        if (isImported) {
            newTables.add(tableName);
        } else {
            newTables.remove(tableName);
        }
        state.importedTables = Collections.unmodifiableList(newTables);
    }

    public void setLastTimestamp(String dbName, long lastTimestamp) {
        SyncState state = getState(dbName, true);
        state.lastTimestamp = lastTimestamp;
    }

    /**
     * Prepares pending updates for writing within transaction {@code tx}.
     * For every pending state a row is emitted that sets only the columns differing from the cached
     * upstream row (all columns when no upstream row is cached); absent values are written as entity
     * nodes. States without any change produce no row.
     *
     * @return a completed future with a single write snapshot, or an empty list if nothing changed
     */
    public CompletableFuture<List<TableWriteSnapshot>> prepare(YtSupport.Transaction tx) {
        List<FlatRow> updates = new ArrayList<>();
        for (SyncState newState : updatedStates.values()) {
            boolean changed = false;
            FlatRow row = new FlatRow(SCHEMA.getColumnsCount());
            row.set(COLUMN_DBNAME, YTree.builder().value(newState.dbName).build());
            SyncState upstreamState = upstreamStates.get(newState.dbName);
            try {
                if (upstreamState == null || !Objects.equals(newState.gtidSet, upstreamState.gtidSet)) {
                    if (newState.gtidSet != null) {
                        row.set(COLUMN_GTID_SET, YTree.builder().value(newState.gtidSet).build());
                    } else {
                        row.set(COLUMN_GTID_SET, YTree.builder().entity().build());
                    }
                    changed = true;
                }
                if (upstreamState == null || !Objects.equals(newState.serverSchema, upstreamState.serverSchema)) {
                    if (newState.serverSchema != null) {
                        row.set(COLUMN_SERVER_SCHEMA,
                                YTree.builder().value(objectMapper.writeValueAsString(newState.serverSchema)).build());
                    } else {
                        row.set(COLUMN_SERVER_SCHEMA, YTree.builder().entity().build());
                    }
                    changed = true;
                }
                if (upstreamState == null || !Objects.equals(newState.lastTimestamp, upstreamState.lastTimestamp)) {
                    if (newState.lastTimestamp != null) {
                        row.set(COLUMN_LAST_TIMESTAMP, YTree.builder().value(newState.lastTimestamp).build());
                    } else {
                        row.set(COLUMN_LAST_TIMESTAMP, YTree.builder().entity().build());
                    }
                    changed = true;
                }
                if (upstreamState == null || !Objects.equals(newState.isImported, upstreamState.isImported)) {
                    if (newState.isImported != null) {
                        row.set(COLUMN_IS_IMPORTED, YTree.builder().value(newState.isImported).build());
                    } else {
                        row.set(COLUMN_IS_IMPORTED, YTree.builder().entity().build());
                    }
                    changed = true;
                }
                if (upstreamState == null || !Objects.equals(newState.importedTables, upstreamState.importedTables)) {
                    if (newState.importedTables != null) {
                        row.set(COLUMN_IMPORTED_TABLES,
                                YTree.builder().value(String.join(",", newState.importedTables)).build());
                    } else {
                        row.set(COLUMN_IMPORTED_TABLES, YTree.builder().entity().build());
                    }
                    changed = true;
                }
            } catch (JsonProcessingException e) {
                // NOTE(review): the state stays in updatedStates and committed() will still move it into
                // upstreamStates even though it was not written. Confirm this divergence is acceptable.
                logger.warn("Failed to serialize sync state for {}", newState.dbName, e);
                continue;
            }
            if (changed) {
                updates.add(row);
            }
        }
        if (updates.isEmpty()) {
            return CompletableFuture.completedFuture(Collections.emptyList());
        }
        return CompletableFuture.completedFuture(Collections.singletonList(new TableWriteSnapshot(
                getPath(),
                SCHEMA,
                Collections.emptyList(),
                updates,
                Collections.emptyList())));
    }

    /**
     * Marks all accumulated updates as committed: moves them into the upstream cache, reports shard
     * time lag (best effort) and clears the pending updates.
     */
    public void committed() {
        if (updatedStates.isEmpty()) {
            return;
        }
        if (updatedStates.size() == 1) {
            String gtidSets = updatedStates.values().stream()
                    .map(state -> String.format("%s:%s", state.getDbName(), state.getGtidSet()))
                    // gtid sets may span several lines; keep the log record on one line
                    .map(str -> str.replace("\n", ""))
                    .collect(Collectors.joining(";"));
            logger.info("Committed {} sync states: {}", updatedStates.size(), gtidSets);
        } else {
            logger.info("Committed {} sync states", updatedStates.size());
        }
        upstreamStates.putAll(updatedStates);
        sendShardsTimeLagToMonitoringSystem();
        updatedStates.clear();
    }

    /**
     * Reports shard -> last timestamp of the pending states to {@code timeLagConsumer}.
     * Best effort: any failure, including while building the map (shard extraction, duplicate shard
     * keys), is logged and never propagates, so it cannot break {@link #committed()} bookkeeping.
     */
    private void sendShardsTimeLagToMonitoringSystem() {
        try {
            Map<Integer, Long> shardsLastTimestamp = StreamEx.of(updatedStates.values())
                    .filter(s -> Objects.nonNull(s.lastTimestamp))
                    .mapToEntry(s -> YtSyncCommonUtil.extractShard(s.dbName), s -> s.lastTimestamp)
                    .toMap();
            timeLagConsumer.accept(shardsLastTimestamp);
        } catch (Exception e) {
            logger.error("Got exception on sending time lag", e);
        }
    }

    /**
     * Clears all cached data (both upstream and pending states).
     */
    public void clear() {
        updatedStates.clear();
        upstreamStates.clear();
    }
}
