package ru.yandex.chemodan.app.dataapi.core.dao.support;

import org.joda.time.Duration;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionException;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.DefaultTransactionStatus;

import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.DataApiShardResolver;
import ru.yandex.chemodan.app.dataapi.core.dao.UserShardId;
import ru.yandex.chemodan.app.dataapi.core.dao.UserShardInfo;
import ru.yandex.chemodan.app.dataapi.web.ReadonlyException;
import ru.yandex.chemodan.util.jdbc.logging.LoggingQueryInterceptorConfiguration;
import ru.yandex.commune.db.shard2.Shard2;
import ru.yandex.commune.db.shard2.ShardManager2;
import ru.yandex.misc.ThreadLocalX;
import ru.yandex.misc.lang.StringUtils;
import ru.yandex.misc.log.mlf.Level;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.monica.annotation.GroupByDefault;
import ru.yandex.misc.monica.annotation.MonicaContainer;
import ru.yandex.misc.monica.annotation.MonicaMetric;
import ru.yandex.misc.monica.core.blocks.Instrument;
import ru.yandex.misc.monica.core.name.MetricGroupName;
import ru.yandex.misc.monica.core.name.MetricName;
import ru.yandex.misc.monica.util.measure.Measured;
import ru.yandex.misc.monica.util.measure.Measurer;
import ru.yandex.misc.time.TimeUtils;

/**
 * Opens, commits and rolls back transactions on the shard that a user belongs to,
 * recording timing metrics and log entries for each of these operations.
 *
 * <p>The shard chosen in {@link #getTransaction(DataApiUserId, TransactionDefinition)} is kept in
 * {@code TransactionUserShardIdHolder} until the transaction is committed or rolled back, and
 * both {@link #commit} and {@link #rollback} verify that the holder still refers to the same user.
 *
 * @author tolmalev
 */
public class ShardedTransactionManager implements MonicaContainer {

    private final Logger logger;

    private final ShardManager2 shardManager;
    private final DataApiShardResolver shardResolver;
    private final LoggingQueryInterceptorConfiguration loggingQueryConf;

    @MonicaMetric
    @GroupByDefault
    private final Instrument get = new Instrument();
    @MonicaMetric
    @GroupByDefault
    private final Instrument commit = new Instrument();
    @MonicaMetric
    @GroupByDefault
    private final Instrument rollback = new Instrument();

    // Set to false when a transaction is opened and to true as soon as commit() is entered;
    // rollback() uses it to decide whether the "shard not changed" check is still needed.
    private final ThreadLocalX<Boolean> hadCommitAttempt = new ThreadLocalX<>();

    /**
     * @param shardManager     source of {@link Shard2} instances by shard id
     * @param shardResolver    maps a user id to the shard holding that user's data
     * @param loggingQueryConf supplies the logger and the "long operation" threshold used for logging
     */
    public ShardedTransactionManager(
            ShardManager2 shardManager, DataApiShardResolver shardResolver,
            LoggingQueryInterceptorConfiguration loggingQueryConf)
    {
        this.shardManager = shardManager;
        this.shardResolver = shardResolver;
        this.loggingQueryConf = loggingQueryConf;
        this.logger = loggingQueryConf.logger;
    }

    /**
     * Opens a transaction for the user with a default {@link DefaultTransactionDefinition}.
     *
     * @param uid user whose shard the transaction is opened on
     * @return status of the newly obtained transaction
     */
    public TransactionStatus getTransaction(DataApiUserId uid) {
        return getTransaction(uid, new DefaultTransactionDefinition());
    }

    /**
     * Resolves the user's shard, remembers it in {@code TransactionUserShardIdHolder}, resets the
     * commit-attempt flag and obtains a transaction on that shard.
     *
     * <p>If obtaining the transaction fails, the holder is cleared before the error is rethrown.
     *
     * @param uid        user whose shard the transaction is opened on
     * @param definition transaction settings; its read-only flag is passed to shard resolution
     * @return status of the newly obtained transaction
     */
    public TransactionStatus getTransaction(final DataApiUserId uid, final TransactionDefinition definition) {
        UserShardInfo userShard = getUserShard(uid, definition.isReadOnly());
        TransactionUserShardIdHolder.set(userShard.getUserShardId());
        try {
            hadCommitAttempt.set(false);
            return doGetTransaction(userShard.getUserShardId(), definition);

        } catch (Throwable t) {
            TransactionUserShardIdHolder.remove();
            throw t;
        }
    }

    /**
     * Obtains the transaction from the shard's own transaction manager, updating the {@code get}
     * metric and logging the outcome together with the elapsed time.
     */
    private TransactionStatus doGetTransaction(UserShardId userShard, TransactionDefinition definition) {
        Measured<TransactionStatus> measured = Measurer.I.measure(
                () -> {
                    Shard2 shard = shardManager.getShard(userShard.shardId);
                    logger.trace("Creating transaction: {}", definition);
                    return shard.getTransactionManager().getTransaction(definition);
                });

        get.update(measured.info());
        boolean successful = measured.info().isSuccessful();
        if (successful) {
            TransactionStatus status = measured.cont();
            log(status, measured.info().elapsed(), successful, "Create transaction: {}, took: {}");
            return status;
        } else {
            // No status is available on failure, so the definition is logged instead.
            log(definition, measured.info().elapsed(), successful, "Create transaction: {}, took: {}");
            // NOTE(review): cont() presumably rethrows the captured failure here - confirm.
            return measured.cont();
        }
    }

    /**
     * Commits the transaction on the user's shard.
     *
     * <p>Before committing, verifies that the holder refers to this user and that the user's
     * shard has not changed since the transaction was opened. The holder is cleared only after
     * a successful commit; when this method throws, the holder is left in place.
     * NOTE(review): presumably so that a subsequent {@link #rollback} can still find the shard - confirm.
     *
     * @param uid    user the transaction was opened for
     * @param status transaction to commit
     * @throws ReadonlyException    if the user is in read-only mode and the transaction is not read-only
     * @throws TransactionException declared for failures of the underlying transaction manager
     */
    public void commit(final DataApiUserId uid, final TransactionStatus status)
            throws TransactionException
    {
        TransactionUserShardIdHolder.checkHoldsUser(uid);
        hadCommitAttempt.set(true);

        UserShardInfo userShard = getUserShard(uid, isReadOnly(status));
        TransactionUserShardIdHolder.checkUserShardNotChanged(userShard.getUserShardId());

        if (userShard.userIsInRo && !isReadOnly(status)) {
            throw new ReadonlyException(uid);
        }
        doCommit(userShard.getUserShardId(), status);

        TransactionUserShardIdHolder.remove();
    }

    /**
     * Commits via the shard's own transaction manager, updating the {@code commit} metric and
     * logging the outcome together with the elapsed time.
     */
    private void doCommit(UserShardId userShard, TransactionStatus status) {
        Measured<?> measured = Measurer.I.measure(() -> {
            Shard2 shard = shardManager.getShard(userShard.shardId);
            logger.trace("Committing transaction: {}", status);
            shard.getTransactionManager().commit(status);
            return null;
        });

        commit.update(measured.info());
        boolean successful = measured.info().isSuccessful();

        log(status, measured.info().elapsed(), successful, "Commit transaction: {}, took: {}");
        // NOTE(review): result is ignored; cont() is presumably called to rethrow a failure - confirm.
        measured.cont();
    }

    /**
     * Rolls the transaction back on the shard recorded in the holder and always clears the holder.
     *
     * <p>If no commit was attempted for this transaction, additionally verifies after the rollback
     * that the user's shard has not changed since the transaction was opened.
     *
     * @param uid    user the transaction was opened for
     * @param status transaction to roll back
     * @throws TransactionException declared for failures of the underlying transaction manager
     */
    public void rollback(final DataApiUserId uid, final TransactionStatus status)
            throws TransactionException
    {
        TransactionUserShardIdHolder.checkHoldsUser(uid);
        try {
            doRollback(TransactionUserShardIdHolder.getO().get(), status);

            if (!hadCommitAttempt.getO().isSome(true)) {
                TransactionUserShardIdHolder.checkUserShardNotChanged(getUserShard(uid, isReadOnly(status)).getUserShardId());
            }
        } finally {
            TransactionUserShardIdHolder.remove();
        }
    }

    /**
     * Rolls back via the shard's own transaction manager, updating the {@code rollback} metric
     * and logging the outcome together with the elapsed time.
     */
    private void doRollback(UserShardId userShard, TransactionStatus status) {
        Measured<?> measured = Measurer.I.measure(() -> {
            Shard2 shard = shardManager.getShard(userShard.shardId);

            logger.trace("Rolling back transaction: {}", status);
            shard.getTransactionManager().rollback(status);

            return null;
        });

        rollback.update(measured.info());
        boolean successful = measured.info().isSuccessful();
        log(status, measured.info().elapsed(), successful, "Roll back transaction: {}, took: {}");

        // NOTE(review): result is ignored; cont() is presumably called to rethrow a failure - confirm.
        measured.cont();
    }

    /** Resolves the shard for the user, passing the read-only flag through to the resolver. */
    private UserShardInfo getUserShard(DataApiUserId uid, boolean readOnly) {
        return shardResolver.shardByUserId(uid, readOnly);
    }

    /**
     * Logs the outcome of a transaction operation.
     *
     * <p>Level selection: failures are logged at ERROR with a "Failed to ..." message; successful
     * operations are logged at TRACE, or at WARN when they took longer than the configured
     * threshold. Any operation over the threshold gets a "(LONG) " message prefix.
     *
     * @param firstArg     first message argument (the transaction status or definition)
     * @param elapsed      how long the operation took
     * @param isSuccessful whether the operation completed without failure
     * @param message      message template with two placeholders: the subject and the elapsed time
     */
    private void log(Object firstArg, Duration elapsed, boolean isSuccessful, String message) {
        Level level = Level.TRACE;
        if (!isSuccessful) {
            level = Level.ERROR;
            message = "Failed to " + StringUtils.decapitalize(message);
        }
        if (elapsed.getMillis() > loggingQueryConf.longThreshold.get()) {
            // Escalate slow successful operations to WARN; failures stay at ERROR.
            // The previous check compared against Level.DEBUG, which the level is never set to,
            // so the escalation never happened.
            if (isSuccessful) {
                level = Level.WARN;
            }
            message = "(LONG) " + message;
        }
        logger.log(level, message, firstArg,
                TimeUtils.millisecondsToSecondsString(elapsed.getMillis()));
    }

    /**
     * Tells whether the transaction is read-only. Only {@link DefaultTransactionStatus} exposes
     * this flag here, so any other status implementation is treated as not read-only.
     */
    private static boolean isReadOnly(TransactionStatus status) {
        return status instanceof DefaultTransactionStatus && ((DefaultTransactionStatus) status).isReadOnly();
    }

    /** Returns the fixed metric group under which the transaction instruments are published. */
    @Override
    public MetricGroupName groupName(String instanceName) {
        return new MetricGroupName(
                "dataapi",
                new MetricName("dataapi", "transaction"),
                "Transaction manager"
        );
    }
}
