package ru.yandex.chemodan.app.dataapi.core.dao.support;

import org.junit.Ignore;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.TestExecutionListeners;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.bolts.function.Function;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.DataApiShardResolver;
import ru.yandex.chemodan.app.dataapi.core.dao.ShardPartitionDataSource;
import ru.yandex.chemodan.app.dataapi.core.dao.ShardPartitionLocator;
import ru.yandex.chemodan.app.dataapi.core.dao.UserShardInfo;
import ru.yandex.chemodan.app.dataapi.core.dao.test.DataApiRunWithRandomTestRunner;
import ru.yandex.chemodan.app.dataapi.core.dao.test.DataApiShardDaoTestContextLoader;
import ru.yandex.chemodan.app.dataapi.web.ReadonlyException;
import ru.yandex.chemodan.ratelimiter.chunk.ChunkRateLimiter;
import ru.yandex.chemodan.util.jdbc.TransactionForEachTestMethodCreator;
import ru.yandex.commune.db.partition.PartitionResolver;
import ru.yandex.commune.db.partition.rewrite.DefaultPartitionQueryRewriter;
import ru.yandex.commune.db.partition.rewrite.PartitionLocator;
import ru.yandex.commune.db.shard2.Shard2;
import ru.yandex.commune.db.shard2.ShardManager2;
import ru.yandex.devtools.test.annotations.YaIgnore;
import ru.yandex.misc.db.q.SqlCondition;
import ru.yandex.misc.db.q.SqlLimits;
import ru.yandex.misc.db.q.SqlOrder;
import ru.yandex.misc.spring.jdbc.JdbcTemplate3;
import ru.yandex.misc.spring.jdbc.rewriteQuery.QueryRewriter;
import ru.yandex.misc.spring.jdbc.rewriteQuery.RewriteQueryJdbcTemplate;

/**
 * Common base for DAOs whose tables are spread over several shards and,
 * inside each shard, over several partitions.
 * <p>
 * The class knows how to pick the shard for a user, how to build a
 * {@link JdbcTemplate3} whose queries are rewritten for a concrete partition,
 * how to enumerate shard/partition pairs, and how to delete rows in
 * rate-limited chunks.
 *
 * @author tolmalev
 */
@RunWith(DataApiRunWithRandomTestRunner.class)
@ContextConfiguration(
        loader = DataApiShardDaoTestContextLoader.class,
        initializers = DataApiShardDaoTestContextLoader.class
)
@TestExecutionListeners(value = {TransactionForEachTestMethodCreator.class})
@YaIgnore
@Ignore
public abstract class DataApiShardPartitionDaoSupport {

    protected final ShardManager2 shardManager2;
    private final DataApiShardResolver shardResolver;
    protected final PartitionResolver partitionResolver;
    private final DefaultPartitionQueryRewriter queryRewriter;

    /**
     * Takes the shard manager, shard resolver and partition resolver from the
     * given data source and creates a partition query rewriter on top of the
     * partition resolver.
     *
     * @param dataSource holder of the shard/partition infrastructure objects
     */
    protected DataApiShardPartitionDaoSupport(ShardPartitionDataSource dataSource) {
        this.shardManager2 = dataSource.shardManager;
        this.shardResolver = dataSource.shardResolver;
        this.partitionResolver = dataSource.partitionResolver;
        this.queryRewriter = new DefaultPartitionQueryRewriter(partitionResolver);
    }

    /**
     * Builds a template for the given shard whose queries are rewritten for the
     * given partition and additionally passed through an
     * {@code AddCommentQueryRewriter} that receives the user string.
     *
     * @param userStr string form of the user, handed to the comment rewriter
     * @param shard   shard that supplies the underlying template
     * @param pl      partition the queries are bound to
     * @return template wrapping the shard's template with the combined rewriter
     */
    protected JdbcTemplate3 getJdbcTemplate(String userStr, Shard2 shard, PartitionLocator pl) {
        QueryRewriter commentingRewriter =
                new AddCommentQueryRewriter(Option.of(queryRewriter.bind(pl)), Option.empty(), Option.of(userStr));

        return new RewriteQueryJdbcTemplate(shard.getJdbcTemplate3(), commentingRewriter);
    }

    /**
     * Builds a template for the given shard whose queries are rewritten for the
     * given partition (no user comment is attached).
     *
     * @param shard shard that supplies the underlying template
     * @param pl    partition the queries are bound to
     * @return template wrapping the shard's template with the bound rewriter
     */
    protected JdbcTemplate3 getJdbcTemplate(Shard2 shard, PartitionLocator pl) {
        QueryRewriter boundRewriter = queryRewriter.bind(pl);
        return new RewriteQueryJdbcTemplate(shard.getJdbcTemplate3(), boundRewriter);
    }

    /**
     * Builds a template for an explicit shard/partition pair.
     *
     * @param shardPartition pair of shard id and partition
     * @return template bound to that shard and partition
     */
    protected JdbcTemplate3 getJdbcTemplate(ShardPartitionLocator shardPartition) {
        Shard2 targetShard = shardManager2.getShard(shardPartition.shardId);
        return getJdbcTemplate(targetShard, shardPartition.partition);
    }

    /**
     * Builds a template for the user's shard, with the partition chosen by the
     * user's discriminant.
     *
     * @param user    user whose data is going to be accessed
     * @param forRead {@code true} when the template is needed for reading only
     * @return template bound to the user's shard and partition
     * @throws ReadonlyException when a non-read template is requested and the
     *                           resolved shard info marks the user as read-only
     */
    protected JdbcTemplate3 getJdbcTemplate(DataApiUserId user, boolean forRead) {
        String userStr = user.asString();
        Shard2 userShard = getShard(user, forRead);
        return getJdbcTemplate(userStr, userShard, PartitionLocator.byDiscriminant(user.discriminant()));
    }

    /**
     * Same as {@link #getJdbcTemplate(DataApiUserId, boolean)} with {@code forRead = false}.
     */
    protected JdbcTemplate3 getJdbcTemplate(DataApiUserId user) {
        return getJdbcTemplate(user, false);
    }

    /**
     * Same as {@link #getJdbcTemplate(DataApiUserId, boolean)} with {@code forRead = true}.
     */
    protected JdbcTemplate3 getReadJdbcTemplate(DataApiUserId user) {
        return getJdbcTemplate(user, true);
    }

    /**
     * Picks the shard for the user.
     * <p>
     * If {@code TransactionUserShardIdHolder} reports that it holds this user,
     * the shard id it holds is used and the resolver is not consulted (so the
     * read-only check below is skipped as well). Otherwise the shard resolver
     * decides.
     */
    private Shard2 getShard(DataApiUserId user, boolean forRead) {
        if (TransactionUserShardIdHolder.holdsUser(user)) {
            return shardManager2.getShard(TransactionUserShardIdHolder.getHoldingShardId());
        }

        UserShardInfo resolved = shardResolver.shardByUserId(user, forRead);

        // Only requests that are not read-only are rejected for a read-only user.
        boolean writeToReadonlyUser = !forRead && resolved.userIsInRo;
        if (writeToReadonlyUser) {
            throw new ReadonlyException(user);
        }
        return shardManager2.getShard(resolved.shardId);
    }

    /**
     * @return ids of all shards known to the shard manager
     */
    public ListF<Integer> getShards() {
        return listShardIds();
    }

    /**
     * Enumerates every partition of the table inside one shard.
     *
     * @param shardId   shard to enumerate partitions for
     * @param tableName table whose partition count is looked up
     * @return one locator per partition number, from 0 up to the partition count (exclusive)
     */
    protected ListF<ShardPartitionLocator> getShardPartitions(int shardId, String tableName) {
        int perTable = partitionResolver.partitionsCountForTable(tableName);
        return Cf.range(0, perTable).map(partNo -> new ShardPartitionLocator(shardId, partNo));
    }

    /**
     * Enumerates every partition of the table in every shard.
     *
     * @param tableName table whose partition count is looked up
     * @return unmodifiable list of locators, grouped by shard and ordered by
     *         partition number within a shard
     */
    protected ListF<ShardPartitionLocator> getShardPartitions(String tableName) {
        ListF<Integer> allShardIds = listShardIds();
        int perTable = partitionResolver.partitionsCountForTable(tableName);

        ListF<ShardPartitionLocator> locators = Cf.arrayList();
        for (int currentShardId : allShardIds) {
            for (int partNo : Cf.range(0, perTable)) {
                ShardPartitionLocator locator = new ShardPartitionLocator(currentShardId, partNo);
                locators.add(locator);
            }
        }

        return locators.unmodifiable();
    }

    /**
     * Chunked delete without an explicit ordering of the rows to delete.
     *
     * @see #deleteAllByChunks(DataApiUserId, String, ListF, SqlCondition, Option, ChunkRateLimiter)
     */
    public void deleteAllByChunks(
            DataApiUserId uid, String tableName, ListF<String> primaryKeys,
            SqlCondition condition, ChunkRateLimiter rateLimiter)
    {
        Option<SqlOrder> noOrder = Option.empty();
        deleteAllByChunks(uid, tableName, primaryKeys, condition, noOrder, rateLimiter);
    }

    /**
     * Chunked delete where the rows of each chunk are selected in the given order.
     *
     * @see #deleteAllByChunks(DataApiUserId, String, ListF, SqlCondition, Option, ChunkRateLimiter)
     */
    public void deleteAllByChunks(
            DataApiUserId uid, String tableName, ListF<String> primaryKeys,
            SqlCondition condition, SqlOrder order, ChunkRateLimiter rateLimiter)
    {
        Option<SqlOrder> explicitOrder = Option.of(order);
        deleteAllByChunks(uid, tableName, primaryKeys, condition, explicitOrder, rateLimiter);
    }

    /**
     * Deletes all rows of the user's table that match the condition, one
     * limited chunk at a time.
     * <p>
     * Every chunk is a single {@code DELETE ... WHERE (pk) IN (SELECT pk ... FOR UPDATE)}
     * statement executed through the rate limiter; the chunk size is the
     * integer the rate limiter passes to the callback. Deletion goes on while
     * a statement reports exactly as many affected rows as the chunk size and
     * stops after the first statement that reports a different number.
     *
     * @param uid         user whose (write) template is used
     * @param tableName   table to delete from
     * @param primaryKeys primary key column names, joined with {@code ", "} in the query
     * @param condition   row filter; its arguments are passed to every statement
     * @param order       optional ordering of the rows selected for a chunk
     * @param rateLimiter limiter that drives chunk execution
     */
    public void deleteAllByChunks(
            DataApiUserId uid, String tableName, ListF<String> primaryKeys,
            SqlCondition condition, Option<SqlOrder> order, ChunkRateLimiter rateLimiter)
    {
        JdbcTemplate3 jdbc = getJdbcTemplate(uid);
        String keyColumns = primaryKeys.mkString(", ");

        Function<Integer, Boolean> deleteChunk = chunkSize -> {
            SqlLimits chunkLimits = SqlLimits.first(chunkSize);

            String selectKeysSql = "SELECT " + keyColumns + " FROM " + tableName
                    + condition.whereSql() + " "
                    + order.map(o -> o.toSql() + " ").getOrElse("")
                    + chunkLimits.toMysqlLimits() + " FOR UPDATE";

            String deleteSql = "DELETE FROM " + tableName
                    + " WHERE (" + keyColumns + ") IN (" + selectKeysSql + ")";

            // A full chunk means there may be more rows left to delete.
            return jdbc.update(deleteSql, condition.args()) == chunkSize;
        };

        boolean fullChunkDeleted;
        do {
            fullChunkDeleted = rateLimiter.acquirePermitAndExecute(deleteChunk);
        } while (fullChunkDeleted);
    }

    /**
     * Maps every shard of the shard manager to its id.
     */
    private ListF<Integer> listShardIds() {
        return shardManager2.shards().map(shard -> shard.getShardInfo().getId());
    }

}
