package ru.yandex.direct.binlogbroker.ytbootstrap.components;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.binlog.model.BinlogEvent;
import ru.yandex.direct.binlog.model.Operation;
import ru.yandex.direct.utils.Checked;
import ru.yandex.direct.utils.MonotonicTime;
import ru.yandex.direct.utils.NanoTimeClock;
import ru.yandex.direct.utils.io.RuntimeIoException;

/**
 * Owns the in-memory {@link DatabaseYtBootstrapState} of a bootstrap run and persists it to
 * {@link #STATE_FILE} using Java serialization.
 * <p>
 * Mutations done through {@link #batchQueued(List)} and {@link #batchFinished(List)} are guarded by an
 * internal lock; the same lock is held while the state is serialized in {@link #saveState()}.
 */
@ParametersAreNonnullByDefault
class DatabaseYtBootstrapStateManager {

    private static final Logger logger = LoggerFactory.getLogger(DatabaseYtBootstrapStateManager.class);

    /** Minimum interval between two saves triggered by {@link #saveStateOccasionally()}. */
    private static final Duration SAVE_STATE_OCCASIONALLY_DURATION = Duration.ofSeconds(30);
    /** Relative path of the state file, so it is resolved against the current working directory. */
    private static final String STATE_FILE = "db_bootstrap_state.bin";

    /** Guards every read and write of {@link #state} performed by this class. */
    private final Object stateLock = new Object();
    private final DatabaseYtBootstrapState state;
    /**
     * Time of the last successful {@link #saveState()}, {@code null} until the first one.
     * Volatile because {@link #saveState()} assigns it without holding {@link #stateLock}, while
     * {@link #saveStateOccasionally()} reads it from the batch callbacks.
     */
    @Nullable
    private volatile MonotonicTime lastSaveStateTime;

    DatabaseYtBootstrapStateManager(String source, int batchRows, int batchEvents) {
        this.state = new DatabaseYtBootstrapState(source, batchRows, batchEvents);
    }

    /**
     * Returns the live, mutable state object itself (not a copy).
     * <p>
     * Not thread safe: the returned object is also modified by {@link #batchQueued(List)} and
     * {@link #batchFinished(List)} under a lock that is private to this class.
     */
    public DatabaseYtBootstrapState getState() {
        return state;
    }

    /**
     * Loads saved state.
     * If there is no saved state, returns null.
     * If saved state is for different source, returns null.
     *
     * @param expectedSource source name the saved state must have been created for
     * @return the deserialized state, or {@code null} if the file is absent or belongs to another source
     * @throws Checked.CheckedException if the file exists but cannot be read or deserialized
     */
    @Nullable
    static DatabaseYtBootstrapState loadState(String expectedSource) {
        // NOTE(review): native Java deserialization; this is only acceptable as long as the state file
        // is written exclusively by saveState() and is not exposed to untrusted writers.
        try (ObjectInputStream is =
                     new ObjectInputStream(new FileInputStream(Paths.get(STATE_FILE).toFile()))) {
            DatabaseYtBootstrapState loadedState = (DatabaseYtBootstrapState) is.readObject();
            if (expectedSource.equals(loadedState.source)) {
                logger.debug("Loaded saved state");
                return loadedState;
            } else {
                logger.warn("Ignoring loaded state, because it is for different source '{}' (should be '{}')",
                        loadedState.source, expectedSource);
                return null;
            }
        } catch (FileNotFoundException ignored) {
            // No state file has been written yet: this is the documented "no saved state" case.
            return null;
        } catch (IOException | ClassNotFoundException e) {
            throw new Checked.CheckedException("Failed to read saved state", e);
        }
    }

    /**
     * Serializes the current state into a temp file located next to the state file and then moves
     * that temp file over {@link #STATE_FILE}, so that a failed write never truncates an existing
     * state file.
     *
     * @throws RuntimeIoException if the temp file cannot be created, written or moved
     */
    void saveState() {
        final Path path = Paths.get(STATE_FILE).toAbsolutePath();
        try {
            // The prefix must be a bare file name: the target directory is passed separately.
            final File tmpFile = File.createTempFile(
                    path.getFileName().toString(), ".tmp", path.getParent().toFile());
            boolean moved = false;
            try {
                try (ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(tmpFile))) {
                    // Hold the lock while serializing so that a consistent snapshot is written.
                    synchronized (stateLock) {
                        os.writeObject(state);
                    }
                }
                Files.move(tmpFile.toPath(), path,
                        StandardCopyOption.ATOMIC_MOVE, StandardCopyOption.REPLACE_EXISTING);
                moved = true;
            } finally {
                // Clean up on any failure (checked or unchecked); after a successful move
                // the temp file no longer exists under its old name.
                if (!moved && !tmpFile.delete()) {
                    logger.error("Can't delete temp file: {}", tmpFile);
                }
            }
            lastSaveStateTime = NanoTimeClock.now();
            logger.debug("Saved current state");
        } catch (IOException ex) {
            throw new RuntimeIoException(ex);
        }
    }

    /**
     * Calls {@link #saveState()} if the state has never been saved by this instance or if at least
     * {@link #SAVE_STATE_OCCASIONALLY_DURATION} has elapsed since the last successful save.
     */
    void saveStateOccasionally() {
        final MonotonicTime now = NanoTimeClock.now();
        final MonotonicTime lastSave = lastSaveStateTime;
        if (lastSave == null || lastSave.plus(SAVE_STATE_OCCASIONALLY_DURATION).isAtOrBefore(now)) {
            saveState();
        }
    }

    /**
     * Returns the per-table state for {@code tableName}, creating an empty one on first access.
     * <p>
     * Not thread safe: callers in this class invoke it while holding {@link #stateLock}.
     */
    private DatabaseYtBootstrapState.TableState getTableState(String tableName) {
        return state.tables.computeIfAbsent(tableName, name -> new DatabaseYtBootstrapState.TableState());
    }

    /**
     * Checks that the batch is non-empty and that all its events refer to one table.
     *
     * @return the table name shared by all events of the batch
     * @throws IllegalArgumentException if the batch is empty or mixes several tables
     */
    private static String requireSingleTable(List<BinlogEvent> batch) {
        Preconditions.checkArgument(batch.stream().map(BinlogEvent::getTable).collect(Collectors.toSet()).size() == 1,
                "All events in batch should be for the same table");
        return batch.get(0).getTable();
    }

    /**
     * Returns the primary key of the first row of the first non-SCHEMA event that has rows,
     * or {@code null} if the batch contains no such event.
     */
    @Nullable
    private static Map<String, Object> getFirstPrimaryKey(List<BinlogEvent> batch) {
        for (BinlogEvent event : batch) {
            if (event.getOperation() != Operation.SCHEMA && !event.getRows().isEmpty()) {
                return event.getRows().get(0).getPrimaryKey();
            }
        }
        return null;
    }

    /**
     * Returns the primary key of the last row of the last non-SCHEMA event that has rows,
     * or {@code null} if the batch contains no such event.
     */
    @Nullable
    private static Map<String, Object> getLastPrimaryKey(List<BinlogEvent> batch) {
        for (int i = batch.size() - 1; i >= 0; i--) {
            final BinlogEvent event = batch.get(i);
            final List<BinlogEvent.Row> rows = event.getRows();
            if (event.getOperation() != Operation.SCHEMA && !rows.isEmpty()) {
                return rows.get(rows.size() - 1).getPrimaryKey();
            }
        }
        return null;
    }

    /**
     * Records that a batch has been queued: its first primary key is added to the table's
     * in-progress set and its last primary key becomes the table's last read primary key.
     * Batches without any data rows leave the state untouched.
     *
     * @param batch events of a single table
     * @throws IllegalArgumentException if the batch is empty or mixes several tables
     */
    void batchQueued(List<BinlogEvent> batch) {
        final String tableName = requireSingleTable(batch);
        final Map<String, Object> firstPrimaryKey = getFirstPrimaryKey(batch);
        if (firstPrimaryKey == null) {
            return;
        }
        // Pure function of the batch, so it can be computed before taking the lock.
        final Map<String, Object> lastPrimaryKey = getLastPrimaryKey(batch);
        synchronized (stateLock) {
            final DatabaseYtBootstrapState.TableState tableState = getTableState(tableName);
            tableState.batchesInProgress.add(firstPrimaryKey);
            tableState.lastReadPrimaryKey = lastPrimaryKey;
            saveStateOccasionally();
        }
    }

    /**
     * Records that a previously queued batch has been processed: its first primary key is removed
     * from the table's in-progress set. Batches without any data rows leave the state untouched.
     *
     * @param batch events of a single table
     * @throws IllegalArgumentException if the batch is empty or mixes several tables
     */
    void batchFinished(List<BinlogEvent> batch) {
        final String tableName = requireSingleTable(batch);
        final Map<String, Object> firstPrimaryKey = getFirstPrimaryKey(batch);
        if (firstPrimaryKey == null) {
            return;
        }
        synchronized (stateLock) {
            getTableState(tableName).batchesInProgress.remove(firstPrimaryKey);
            saveStateOccasionally();
        }
    }
}
