package ru.yandex.travel.workflow.ha;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

import javax.sql.DataSource;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link MasterLockManager} backed by a PostgreSQL session-level advisory lock.
 *
 * <p>{@link #acquireLock} runs {@code pg_try_advisory_lock(LOCK_ID)} on a connection obtained from
 * the {@link DataSource}. On success that connection is kept open (a session-level advisory lock
 * is tied to the database session). A single daemon thread then pings it with {@code SELECT 1}
 * at a fixed delay. If a ping fails, the connection is closed, the heartbeat is cancelled and the
 * caller's {@link LockLostCallback} is invoked on the heartbeat thread.
 *
 * <p>{@link #acquireLock} and the heartbeat cleanup both synchronize on this instance.
 */
@Slf4j
public class DBMasterLockManager implements MasterLockManager {

    /** Advisory lock key passed to {@code pg_try_advisory_lock}. */
    private static final int LOCK_ID = 1;

    private final DataSource dataSource;

    private final ScheduledExecutorService scheduledExecutorService;

    private final Duration heartbeatDuration;

    private final Duration initialPingDelay;

    /** Ping timeout forwarded to {@link Statement#setQueryTimeout}, which JDBC defines in seconds. */
    private final int pingQueryTimeout;

    /** Currently scheduled heartbeat, or {@code null} when none is active. Guarded by {@code this}. */
    private ScheduledFuture<?> heartbeatFuture;

    /**
     * @param dataSource        source of the connection that will hold the advisory lock
     * @param initialPingDelay  delay before the first heartbeat ping after the lock is acquired
     * @param heartbeatDuration delay between the end of one ping and the start of the next
     * @param pingQueryTimeout  query timeout for each ping, in seconds (JDBC {@code setQueryTimeout})
     */
    public DBMasterLockManager(DataSource dataSource,
                               Duration initialPingDelay,
                               Duration heartbeatDuration,
                               int pingQueryTimeout) {
        this.dataSource = dataSource;
        this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
                new ThreadFactoryBuilder()
                        .setNameFormat("Pinger")
                        .setDaemon(true)
                        .build()
        );
        this.heartbeatDuration = heartbeatDuration;
        this.initialPingDelay = initialPingDelay;
        this.pingQueryTimeout = pingQueryTimeout;
    }

    /**
     * Makes a single, non-blocking attempt to take the advisory lock.
     *
     * <p>On success the connection is retained and a heartbeat is scheduled. When the heartbeat
     * detects a failure it closes the connection, cancels itself and then calls
     * {@code lockLostCallback.lockLost()} on the heartbeat thread. On any failure to acquire, the
     * connection (if obtained) is closed and {@code false} is returned; exceptions are logged,
     * not propagated.
     *
     * @param lockLostCallback notified once if the lock is later considered lost
     * @return {@code true} if the lock was acquired and the heartbeat was scheduled
     * @throws IllegalStateException if a heartbeat from a previous acquisition is still active
     */
    @Override
    public synchronized boolean acquireLock(LockLostCallback lockLostCallback) {
        Preconditions.checkState(heartbeatFuture == null, "There must be no active heartbeat");
        Connection connection = null;
        boolean lockAcquired = false;
        try {
            connection = dataSource.getConnection();
            if (!tryAdvisoryLock(connection)) {
                return false;
            }
            log.info("Lock acquired");
            heartbeatFuture = scheduledExecutorService.scheduleWithFixedDelay(
                    new HeartbeatChecker(connection, () -> {
                        cleanUpHeartBeat();
                        lockLostCallback.lockLost();
                    }, pingQueryTimeout),
                    initialPingDelay.toNanos(),
                    heartbeatDuration.toNanos(),
                    TimeUnit.NANOSECONDS
            );
            lockAcquired = true;
        } catch (Exception e) {
            log.error("Lock not acquired: an exception occurred while acquiring the lock", e);
        } finally {
            // The connection is only kept when it holds the lock and is being heartbeated.
            if (!lockAcquired && connection != null) {
                closeQuietly(connection);
            }
        }
        return lockAcquired;
    }

    /**
     * Runs {@code pg_try_advisory_lock} on {@code connection}; the statement and result set are
     * always closed (the session-level lock itself stays with the connection).
     *
     * @return {@code true} only if the query returned a row whose first column is {@code true}
     */
    private static boolean tryAdvisoryLock(Connection connection) throws SQLException {
        try (Statement stmt = connection.createStatement();
             ResultSet rs = stmt.executeQuery(
                     String.format("SELECT pg_try_advisory_lock(%s);", LOCK_ID))) {
            if (rs == null) {
                log.error("Lock not acquired: no result set returned");
                return false;
            }
            if (!rs.next()) {
                log.error("Lock not acquired: result set is empty");
                return false;
            }
            if (!rs.getBoolean(1)) {
                log.debug("Lock not acquired: lock is already acquired by the other process");
                return false;
            }
            return true;
        }
    }

    /**
     * Cancels and forgets the current heartbeat. Invoked from inside the heartbeat task itself, so
     * {@code cancel(false)} is used: it stops further executions without interrupting the thread
     * that goes on to run the caller's {@code lockLost} callback.
     */
    private synchronized void cleanUpHeartBeat() {
        if (heartbeatFuture != null) {
            heartbeatFuture.cancel(false);
            heartbeatFuture = null;
        }
    }

    /** Closes {@code connection}, logging rather than propagating any failure. */
    private static void closeQuietly(Connection connection) {
        try {
            connection.close();
        } catch (Exception e) {
            log.error("Error closing connection", e);
        }
    }

    /**
     * Periodic liveness probe for the lock-holding connection. Any exception from the
     * {@code SELECT 1} ping (including an unexpected result) is treated as a lost lock.
     */
    private static final class HeartbeatChecker implements Runnable {
        private final Connection connection;
        private final LockLostCallback lockLostCallback;
        /** Query timeout in seconds, as expected by {@link Statement#setQueryTimeout}. */
        private final int pingQueryTimeout;

        public HeartbeatChecker(Connection connection, LockLostCallback lockLostCallback, int pingQueryTimeout) {
            this.connection = connection;
            this.lockLostCallback = lockLostCallback;
            this.pingQueryTimeout = pingQueryTimeout;
        }

        @Override
        public void run() {
            // Resources are closed every tick: this connection lives as long as the lock is held.
            try (Statement stmt = connection.createStatement()) {
                stmt.setQueryTimeout(pingQueryTimeout);
                try (ResultSet rs = stmt.executeQuery("SELECT 1;")) {
                    if (!rs.next()) {
                        throw new IllegalStateException("SELECT 1; returned no rows");
                    }
                    int value = rs.getInt(1);
                    if (value != 1) {
                        throw new IllegalStateException(String.format("SELECT 1; returned different value: %s", value));
                    }
                }
            } catch (Exception e) {
                log.error("Lock lost: an exception occurred while checking the heartbeat", e);
                // Close before notifying, so a callback that re-acquires does not race the old
                // session, and so a throwing callback cannot prevent the close.
                // NOTE(review): if the DataSource pools connections, close() may return a still-live
                // session (e.g. after a mere query timeout) that keeps holding the advisory lock —
                // consider an explicit pg_advisory_unlock or evicting the connection; confirm.
                closeQuietly(connection);
                lockLostCallback.lockLost();
            }
        }
    }

}
