package ru.yandex.chemodan.app.dataapi.core.mdssnapshot;

import org.joda.time.Instant;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.chemodan.app.dataapi.api.db.Database;
import ru.yandex.chemodan.app.dataapi.api.user.DataApiUserId;
import ru.yandex.chemodan.app.dataapi.core.dao.JdbcDaoUtils;
import ru.yandex.chemodan.app.dataapi.core.dao.ShardPartitionDataSource;
import ru.yandex.chemodan.app.dataapi.core.dao.support.DataApiShardPartitionDaoSupport;
import ru.yandex.chemodan.ratelimiter.chunk.ChunkRateLimiter;
import ru.yandex.commune.db.partition.rewrite.PartitionLocator;
import ru.yandex.commune.db.shard2.Shard2;
import ru.yandex.commune.test.random.RunWithRandomTest;
import ru.yandex.misc.db.q.ConditionUtils;
import ru.yandex.misc.db.q.SqlCondition;

/**
 * @author Denis Bakharev
 */
public class MdsSnapshotReferenceJdbcDaoImpl extends DataApiShardPartitionDaoSupport
        implements MdsSnapshotReferenceJdbcDao
{
    public MdsSnapshotReferenceJdbcDaoImpl(ShardPartitionDataSource dataSource) {
        super(dataSource);
    }

    @Override
    @RunWithRandomTest
    public void insert(MdsSnapshotReference mdsSnapshotReference) {
        insertBatch(mdsSnapshotReference.uid, Cf.list(mdsSnapshotReference));
    }

    @Override
    public void insertBatch(DataApiUserId uid, ListF<MdsSnapshotReference> references) {
        JdbcDaoUtils.updateRowOrBatch(getJdbcTemplate(uid), ""
                + "INSERT INTO database_snapshots_references"
                + " (database_handle, database_rev, last_request_time, mds_key, user_id)"
                + " VALUES (?, ?, ?, ?, ?)",
                references, (mdsSnapshotReference) -> Cf.list(
                        mdsSnapshotReference.databaseHandle,
                        mdsSnapshotReference.databaseRev,
                        mdsSnapshotReference.lastRequestTime,
                        mdsSnapshotReference.mdsKey.getOrNull(),
                        mdsSnapshotReference.uid.toString()));
    }

    @Override
    @RunWithRandomTest
    public int updateLastRequestTime(Instant lastRequestTime, Database database) {
        int updatedRowsCount = getJdbcTemplate(database.uid).update(
                "UPDATE database_snapshots_references SET last_request_time = ? WHERE database_handle = ?"
                + " AND database_rev = ? AND user_id = ?",
                lastRequestTime,
                database.handleValue(),
                database.rev,
                database.uid.toString());
        return updatedRowsCount;
    }

    @Override
    @RunWithRandomTest
    public int updateMdsKey(String mdsKey, Database database) {
        int updatedRowsCount = getJdbcTemplate(database.uid).update(
                "UPDATE database_snapshots_references SET mds_key = ? WHERE database_handle = ? AND database_rev = ?"
                + " AND user_id = ?",
                mdsKey,
                database.handleValue(),
                database.rev,
                database.uid.toString());
        return updatedRowsCount;
    }

    @Override
    @RunWithRandomTest
    public Option<MdsSnapshotReference> find(DataApiUserId uid, String handle, long revision) {
        return getReadJdbcTemplate(uid).queryForOption(
                "SELECT * from database_snapshots_references WHERE database_handle = ? AND database_rev = ?"
                + " AND user_id = ?",
                new MdsSnapshotReferenceMapper(),
                handle,
                revision,
                uid.toString());
    }

    @Override
    @RunWithRandomTest
    public ListF<MdsSnapshotReference> find(DataApiUserId uid, ListF<String> handles) {
        if (handles.isEmpty()) return Cf.list();

        SqlCondition cond = ConditionUtils.column("database_handle").inSet(handles)
                .and(ConditionUtils.column("user_id").eq(uid.toString()));

        return getReadJdbcTemplate(uid).query(
                "SELECT * from database_snapshots_references" + cond.whereSql(),
                new MdsSnapshotReferenceMapper(),
                cond.args());
    }

    @Override
    @RunWithRandomTest
    public int delete(MdsSnapshotReference snapshotReference) {
        int deletedRows = getReadJdbcTemplate(snapshotReference.uid).update(
                "DELETE from database_snapshots_references WHERE database_handle = ? AND database_rev = ?"
                + " AND user_id = ?",
                snapshotReference.databaseHandle,
                snapshotReference.databaseRev,
                snapshotReference.uid.toString());
        return deletedRows;
    }

    @Override
    @RunWithRandomTest
    public void delete(DataApiUserId uid, ListF<String> handles) {
        if (handles.isEmpty()) return;

        SqlCondition cond = ConditionUtils.column("database_handle").inSet(handles)
                .and(ConditionUtils.column("user_id").eq(uid.toString()));

        getJdbcTemplate(uid).update("DELETE FROM database_snapshots_references" + cond.whereSql(), cond.args());
    }

    @Override
    @RunWithRandomTest
    public void delete(DataApiUserId uid, ListF<String> handles, ChunkRateLimiter limiter) {
        if (handles.isEmpty()) return;

        SqlCondition condition = ConditionUtils.column("database_handle").inSet(handles)
                .and(ConditionUtils.column("user_id").eq(uid.toString()));
        deleteAllByChunks(uid, "database_snapshots_references", Cf.list("database_handle", "database_rev"),
                condition, limiter);
    }

    @Override
    @RunWithRandomTest
    public ListF<MdsSnapshotReference> findAllWithCreationTimeLessThan(DataApiUserId uid, Instant instant) {
        ListF<MdsSnapshotReference> result = Cf.arrayList();
        for (Shard2 shard : shardManager2.shards()) {
            ListF<MdsSnapshotReference> oneShardResult = getJdbcTemplate(shard, PartitionLocator.noRewrite()).query(
                    "SELECT * from database_snapshots_references WHERE" +
                            " user_id = '" + uid.toString() + "'" +
                            " AND last_request_time < ?",
                    new MdsSnapshotReferenceMapper(),
                    instant);
            result.addAll(oneShardResult);
        }

        return result;
    }

    @Override
    @RunWithRandomTest
    public ListF<MdsSnapshotReference> findAllWithCreationTimeLessThan(Instant instant) {
        ListF<MdsSnapshotReference> result = Cf.arrayList();
        for (Shard2 shard : shardManager2.shards()) {
            ListF<MdsSnapshotReference> oneShardResult = findAllWithCreationTimeLessThan(instant, shard);
            result.addAll(oneShardResult);
        }

        return result;
    }

    public ListF<MdsSnapshotReference> findAllWithCreationTimeLessThan(Instant instant, Shard2 shard) {
        return getJdbcTemplate(shard, PartitionLocator.noRewrite()).query(
                "SELECT * from database_snapshots_references WHERE last_request_time < ?",
                new MdsSnapshotReferenceMapper(),
                instant);
    }

    public ListF<ListF<MdsSnapshotReference>> findAllWithCreationTimeLessThanPartitionedByShard(Instant instant) {
        return shardManager2.shards().map(shard -> findAllWithCreationTimeLessThan(instant, shard));
    }
}
