package ru.yandex.chemodan.util.jdbc;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.concurrent.atomic.AtomicReference;

import javax.sql.DataSource;

import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.bolts.collection.Unit;
import ru.yandex.chemodan.util.jdbc.logging.LastAccessedDsAwareQueryInterceptor;
import ru.yandex.chemodan.util.jdbc.logging.LastAccessedDsSource;
import ru.yandex.commune.alive2.location.Location;
import ru.yandex.commune.alive2.location.LocationResolver;
import ru.yandex.misc.ThreadLocalX;
import ru.yandex.misc.db.DataSourceUtils;
import ru.yandex.misc.db.LocationExposedDataSource;
import ru.yandex.misc.db.SqlAdminAwareDataSource;
import ru.yandex.misc.db.UrlExposedConnection;
import ru.yandex.misc.db.masterSlave.MasterSlavePolicy;
import ru.yandex.misc.db.masterSlave.dynamic.DynamicDatasourceHolder;
import ru.yandex.misc.db.masterSlave.dynamic.DynamicMasterSlaveDataSource;
import ru.yandex.misc.db.postgres.PgBouncerFamiliarConnection;
import ru.yandex.misc.db.url.JdbcUrl;
import ru.yandex.misc.io.IoFunction;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;

/**
 * Dynamic master/slave data source that:
 * <ul>
 *   <li>reorders slave checkers by data-center locality (see {@link #getSlaveCheckers(ListF)});</li>
 *   <li>remembers the URL of the data source its check workers last saw as master and tags every
 *       handed-out connection with its URL and whether it is that master;</li>
 *   <li>fails the ping of a slave whose replication lag, measured as seconds since
 *       {@code repl_mon.ts}, reaches {@code replicationMaximumLag}.</li>
 * </ul>
 *
 * @author osidorkin
 * @author fromen
 */
public class DcAwareDynamicMasterSlaveDataSource extends DynamicMasterSlaveDataSource implements
        LastAccessedDsSource, SqlAdminAwareDataSource
{
    private static final Logger logger = LoggerFactory.getLogger(DcAwareDynamicMasterSlaveDataSource.class);
    /** Checkers located in this data center are moved to the end of the list when reordering. */
    private static final String MAN_DC_NAME = "man";

    private final LocationResolver locationResolver;
    /** Replication lag threshold, in seconds, at or above which a slave's ping fails. */
    private final int replicationMaximumLag;
    private final long delayBetweenPingsMillis;

    private final LastAccessedDsAwareQueryInterceptor lastAccessedDsAwareQueryInterceptor;
    /** Set on the calling thread only for the duration of {@link #getConnectionForSqlAdmin()}. */
    private final ThreadLocalX<Unit> forSqlAdmin = new ThreadLocalX<>();

    /**
     * URL (as returned by a check worker's {@code getUrl()}) of the data source last reported as
     * master, if any. Every check worker reads and writes this, so updates that depend on the
     * current value must be atomic (see {@link DataSourceReplicationLagAwareCheckWorker#ping()}).
     */
    private final AtomicReference<Option<String>> masterUrl = new AtomicReference<>(Option.empty());

    /**
     * @param dataSources candidate master/slave data sources
     * @param options options passed through to the superclass
     * @param locationResolver resolves the data center of the current process
     * @param replicationMaximumLag lag threshold in seconds for slaves
     * @param delayBetweenPingsMillis delay between consecutive pings of one check worker
     * @param lastAccessedDsAwareQueryInterceptor interceptor that records the last accessed data source
     * @param connectionHandler connection post-processor passed through to the superclass
     */
    public DcAwareDynamicMasterSlaveDataSource(
            ListF<DataSource> dataSources, Options options, LocationResolver locationResolver,
            int replicationMaximumLag, long delayBetweenPingsMillis,
            LastAccessedDsAwareQueryInterceptor lastAccessedDsAwareQueryInterceptor,
            IoFunction<Connection, Connection> connectionHandler)
    {
        super(dataSources, options, connectionHandler);
        this.locationResolver = locationResolver;
        this.replicationMaximumLag = replicationMaximumLag;
        this.delayBetweenPingsMillis = delayBetweenPingsMillis;
        this.lastAccessedDsAwareQueryInterceptor = lastAccessedDsAwareQueryInterceptor;
    }

    /**
     * Obtains a connection via {@link #getConnection()} while the SQL-admin flag is set on this
     * thread, so {@link #getConnection(DynamicDatasourceHolder)} can request an SQL-admin
     * connection from the chosen underlying data source.
     * NOTE(review): assumes the superclass routes getConnection() through
     * getConnection(DynamicDatasourceHolder) — confirm.
     */
    @Override
    public Connection getConnectionForSqlAdmin() throws SQLException {
        forSqlAdmin.set(Unit.U);
        try {
            return getConnection();
        } finally {
            forSqlAdmin.remove();
        }
    }

    /** Returns the superclass's slave checkers reordered by {@link #getDcAwareCheckers(ListF)}. */
    @Override
    protected ListF<DynamicDatasourceHolder> getSlaveCheckers(ListF<DynamicDatasourceHolder> holders) {
        ListF<DynamicDatasourceHolder> slaves = super.getSlaveCheckers(holders);
        return getDcAwareCheckers(slaves);
    }

    /**
     * Reorders checkers by data-center locality.
     *
     * <p>If the first checker's data source exposes a location whose DC differs from the DC of the
     * current process, the checkers are shuffled and those in {@link #MAN_DC_NAME} are moved to the
     * end. Otherwise, including when the list is empty or the first data source exposes no location,
     * the list is returned unchanged.
     */
    private ListF<DynamicDatasourceHolder> getDcAwareCheckers(ListF<DynamicDatasourceHolder> checkers) {
        if (!checkers.isEmpty()) {
            DataSource dataSource = checkers.first().getDataSource();
            if (dataSource instanceof LocationExposedDataSource) {
                Location dcLocation = ((LocationExposedDataSource) dataSource).getLocation();
                if (!dcLocation.dcName.equals(locationResolver.resolveLocation().dcName)) {
                    checkers = checkers.shuffle();

                    // Move checkers located in "man" to the end; t._1 holds the "man" ones.
                    Tuple2<ListF<DynamicDatasourceHolder>, ListF<DynamicDatasourceHolder>> t =
                            checkers.partition(DcAwareDynamicMasterSlaveDataSource::isMan);
                    return t._2.plus(t._1);
                }
            }
        }
        return checkers;
    }

    /** Whether the checker's data source exposes a location whose DC name is {@link #MAN_DC_NAME}. */
    private static boolean isMan(DynamicDatasourceHolder checker) {
        DataSource dataSource = checker.getDataSource();
        if (dataSource instanceof LocationExposedDataSource) {
            Location dcLocation = ((LocationExposedDataSource) dataSource).getLocation();
            // map(...).getOrElse(false) already yields false for an absent DC name.
            return dcLocation.dcName.map(MAN_DC_NAME::equals).getOrElse(false);
        }

        return false;
    }

    /**
     * Returns the checkers reordered by {@link #getDcAwareCheckers(ListF)}.
     * NOTE(review): not called within this class; presumably used by the superclass or subclasses.
     */
    protected ListF<DynamicDatasourceHolder> getCheckers(ListF<DynamicDatasourceHolder> checkers) {
        return getDcAwareCheckers(checkers);
    }

    /**
     * Obtains a connection for {@code policy}, records its URL in the last-accessed holder
     * together with whether it is the current master, and wraps it in a
     * {@link PgBouncerFamiliarConnection}.
     *
     * <p>On any failure, the last-accessed holder is cleared. If a connection was already obtained,
     * it is closed so it does not leak, and the original error is rethrown.
     */
    @Override
    protected Connection doGetConnectionForPolicy(MasterSlavePolicy policy) {
        Connection connection = null;
        try {
            connection = super.doGetConnectionForPolicy(policy);

            JdbcUrl url = UrlExposedConnection.getUrl(connection);

            lastAccessedDsAwareQueryInterceptor.getLastAccessedDsHolder().set(
                    url, masterUrl.get().exists(url.getShortName()::equals));

            return new PgBouncerFamiliarConnection(connection, Option.of(lastAccessedDsAwareQueryInterceptor));

        } catch (Throwable t) {
            lastAccessedDsAwareQueryInterceptor.getLastAccessedDsHolder().remove();
            closeAfterFailure(connection, t);
            throw t;
        }
    }

    /**
     * Closes a connection that will not be handed to the caller. A failure to close is attached to
     * {@code cause} as a suppressed exception rather than masking it. Does nothing for {@code null}.
     */
    private static void closeAfterFailure(Connection connection, Throwable cause) {
        if (connection == null) {
            return;
        }
        try {
            connection.close();
        } catch (SQLException | RuntimeException e) {
            cause.addSuppressed(e);
        }
    }

    /** Creates a check worker that additionally tracks the master URL and checks replication lag. */
    @Override
    protected DataSourceCheckWorker newCheckWorker(DataSource ds) {
        return new DataSourceReplicationLagAwareCheckWorker(ds);
    }

    /**
     * Gets a connection from the holder's data source. Inside {@link #getConnectionForSqlAdmin()},
     * an {@link SqlAdminAwareDataSource} is asked for an SQL-admin connection instead.
     */
    @Override
    protected Connection getConnection(DynamicDatasourceHolder holder) throws SQLException {
        DataSource dataSource = holder.getDataSource();

        return forSqlAdmin.isSet() && dataSource instanceof SqlAdminAwareDataSource
                ? ((SqlAdminAwareDataSource) dataSource).getConnectionForSqlAdmin()
                : dataSource.getConnection();
    }

    /**
     * Queries the replication lag of {@code dataSource} in seconds (now minus {@code repl_mon.ts}).
     *
     * @throws RuntimeException if the lag is at least {@code replicationMaximumLag} and either the
     *         first check has not completed yet or a master data source is available; otherwise a
     *         large lag is only logged as a warning
     */
    private void checkDatabaseReplicationLag(DataSource dataSource) {
        JdbcTemplate3 jdbcTemplate = new JdbcTemplate3(dataSource);
        int lag = jdbcTemplate.query("SELECT extract(epoch from clock_timestamp() - ts) FROM repl_mon ",
                (resultSet, i) -> resultSet.getInt(1)).first();
        logger.debug("DataSource {} has lag: {}s",
                DataSourceUtils.shortUrlOrSomething(dataSource), lag);
        if (lag >= replicationMaximumLag) {
            if (!isFirstCheckCompleted() || hasAvailableMasterDataSource()) {
                throw new RuntimeException("DataSource " + DataSourceUtils.shortUrlOrSomething(dataSource) +
                        " has big replication lag " + lag + "s");
            } else {
                logger.warn("No available master found. Ignoring replication lag of {}s", lag);
            }
        }
    }

    /**
     * Check worker that pings every {@code delayBetweenPingsMillis} ms, publishes its URL to
     * {@code masterUrl} while it sees a master, and checks replication lag while it sees a slave.
     */
    protected class DataSourceReplicationLagAwareCheckWorker extends DataSourceCheckWorker {
        protected DataSourceReplicationLagAwareCheckWorker(DataSource dataSource) {
            super(dataSource);
        }

        @Override
        protected long delayBetweenExecutionsMillis() {
            return delayBetweenPingsMillis;
        }

        @Override
        protected void ping() throws Exception {
            super.ping();
            if (!isMaster) {
                Option<String> current = masterUrl.get();
                if (current.isSome(getUrl())) {
                    // Clear only if no other worker has published a new master since we read it;
                    // a plain read-then-write could erase a freshly promoted master's URL.
                    masterUrl.compareAndSet(current, Option.empty());
                }
                checkDatabaseReplicationLag(getCheckerDataSource());
            } else {
                masterUrl.set(Option.of(getUrl()));
            }
        }
    }

    /** Returns the interceptor that records the last accessed data source. */
    @Override
    public LastAccessedDsAwareQueryInterceptor getLastAccessedDsInterceptor() {
        return lastAccessedDsAwareQueryInterceptor;
    }
}
