package ru.yandex.direct.mysql;

import java.time.Duration;

import com.github.shyiko.mysql.binlog.event.QueryEventData;
import com.github.shyiko.mysql.binlog.event.RowsQueryEventData;

import ru.yandex.direct.mysql.schema.ServerSchema;
import ru.yandex.direct.utils.Interrupts;

/**
 * A {@link MySQLBinlogConsumer} decorator that forwards every event to a delegate consumer and
 * counts how many transaction commits the delegate has handled.
 *
 * <p>Other threads can read the counter with {@link #getTransactionsCount()} or block until it
 * reaches a threshold with {@link #waitForTransactionsCount(long, Duration)}. The counter is
 * guarded by this object's monitor.
 */
public class TransactionsCountWatcher implements MySQLBinlogConsumer {
    private final MySQLBinlogConsumer consumer;

    /**
     * Number of {@link #onTransactionCommit} calls whose delegate call returned normally.
     * Guarded by {@code this}.
     */
    private long transactionsCount;

    /**
     * @param consumer the consumer that every binlog callback is forwarded to
     */
    public TransactionsCountWatcher(MySQLBinlogConsumer consumer) {
        this.consumer = consumer;
    }

    @Override
    public void onConnect(MySQLBinlogDataStreamer streamer) {
        consumer.onConnect(streamer);
    }

    @Override
    public void onDisconnect(MySQLBinlogDataStreamer streamer) {
        consumer.onDisconnect(streamer);
    }

    @Override
    public void onDDL(String gtid, QueryEventData data, ServerSchema before, ServerSchema after) {
        consumer.onDDL(gtid, data, before, after);
    }

    @Override
    public void onTransactionBegin(String gtid) {
        consumer.onTransactionBegin(gtid);
    }

    @Override
    public void onRowsQuery(RowsQueryEventData data, long timestamp) {
        consumer.onRowsQuery(data, timestamp);
    }

    @Override
    public void onInsertRows(MySQLSimpleData data) {
        consumer.onInsertRows(data);
    }

    @Override
    public void onUpdateRows(MySQLUpdateData data) {
        consumer.onUpdateRows(data);
    }

    @Override
    public void onDeleteRows(MySQLSimpleData data) {
        consumer.onDeleteRows(data);
    }

    /**
     * Forwards the commit to the delegate, then increments the counter and wakes all waiters.
     *
     * <p>The counter is only incremented after the delegate returns. If the delegate throws, the
     * commit is not counted.
     */
    @Override
    public void onTransactionCommit(String gtid) {
        consumer.onTransactionCommit(gtid);
        synchronized (this) {
            transactionsCount++;
            notifyAll();
        }
    }

    /**
     * Returns the number of commits counted so far.
     */
    public synchronized long getTransactionsCount() {
        return transactionsCount;
    }

    /**
     * Blocks until at least {@code minTransactionsCount} commits have been counted, or until
     * {@code timeout} elapses.
     *
     * <p>Returns immediately if the threshold is already reached. The actual waiting is delegated
     * to {@code Interrupts.criticalTimeoutWait}; how thread interrupts are treated is decided by
     * that helper.
     *
     * @param minTransactionsCount the counter value to wait for
     * @param timeout maximum total time to wait
     * @throws IllegalStateException if the threshold is not reached within {@code timeout}
     */
    public void waitForTransactionsCount(long minTransactionsCount, Duration timeout) {
        synchronized (this) {
            // criticalTimeoutWait takes a ready-made exception, so its message would describe the
            // count at the *start* of the wait. Hand it a marker; if that exact instance is thrown,
            // replace it with one that reports the count at the moment the wait gave up.
            IllegalStateException timeoutMarker = timeoutException(minTransactionsCount, timeout);
            try {
                Interrupts.criticalTimeoutWait(
                        timeout,
                        remainingTimeout -> {
                            // Check before waiting, so an already-satisfied threshold does not block.
                            if (transactionsCount >= minTransactionsCount) {
                                return true;
                            }
                            // Wait only for what is left of the overall budget, not the full timeout.
                            Interrupts.waitMillisNanos(this::wait).await(remainingTimeout);
                            return transactionsCount >= minTransactionsCount;
                        },
                        timeoutMarker
                );
            } catch (IllegalStateException e) {
                if (e != timeoutMarker) {
                    throw e;
                }
                throw timeoutException(minTransactionsCount, timeout);
            }
        }
    }

    /**
     * Builds the timeout exception from the current counter value.
     * The caller must hold this object's monitor.
     */
    private IllegalStateException timeoutException(long minTransactionsCount, Duration timeout) {
        return new IllegalStateException(
                "Wait for transactions timed out. Expected " + minTransactionsCount +
                        ", got only " + transactionsCount + " after " + timeout.toMillis() + " milliseconds."
        );
    }
}
