package ru.yandex.solomon.coremon.tasks.deleteMetrics;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;

import com.google.common.collect.Streams;
import io.grpc.Status;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import ru.yandex.coremon.api.task.DeleteMetricsParams;
import ru.yandex.coremon.api.task.DeleteMetricsRollbackProgress.RollbackDeletedMetricsProgress;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.solomon.coremon.meta.CoremonMetric;
import ru.yandex.solomon.coremon.meta.FileCoremonMetric;
import ru.yandex.solomon.coremon.meta.db.memory.InMemoryDeletedMetricsDao;
import ru.yandex.solomon.coremon.meta.db.memory.InMemoryMetricsDao;
import ru.yandex.solomon.coremon.meta.db.memory.InMemoryMetricsDaoFactory;
import ru.yandex.solomon.coremon.meta.service.MetabaseShardConf;
import ru.yandex.solomon.coremon.meta.service.MetabaseShardResolverStub;
import ru.yandex.solomon.metrics.client.StockpileClientStub;
import ru.yandex.solomon.util.future.RetryConfig;

import static java.util.Collections.shuffle;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.concurrent.ForkJoinPool.commonPool;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toUnmodifiableList;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThat;
import static ru.yandex.solomon.coremon.tasks.deleteMetrics.DeleteMetricsAssert.assertMetrics;
import static ru.yandex.solomon.coremon.tasks.deleteMetrics.DeleteMetricsRandom.metricInSpShard;
import static ru.yandex.solomon.coremon.tasks.deleteMetrics.DeleteMetricsRandom.shardsWithNumIdUpTo;
import static ru.yandex.solomon.util.CloseableUtils.close;

/**
 * Tests for {@link RollbackDeletedMetrics}: moving metrics recorded in the deleted-metrics DAO
 * for a given operation/shard back into the shard's metrics DAO, including progress reporting,
 * quota exhaustion, collisions, resumption, retries, cancellation and interrupted mode.
 *
 * @author Stanislav Kashirin
 */
public class RollbackDeletedMetricsTest {

    private static final int TEST_FIND_METRICS_LIMIT = 1000;

    private static final int MAX_NUM_ID = 3;
    private static final List<MetabaseShardConf> SHARDS = shardsWithNumIdUpTo(MAX_NUM_ID);

    @Rule
    public Timeout globalTimeout = Timeout.builder()
        .withTimeout(1, MINUTES)
        .withLookingForStuckThread(true)
        .build();

    private RetryConfig retryConfig;
    private InMemoryDeletedMetricsDao deletedMetricsDao;
    private InMemoryMetricsDaoFactory metricsDaoFactory;
    private MetabaseShardResolverStub shardResolver;

    /**
     * Builds fresh in-memory fakes for every test.
     *
     * <p>Retries are effectively unlimited with no delay, and shard initialization is suspended on
     * creation so that each test decides when a shard becomes ready (see {@link #ensureShardReady}).
     */
    @Before
    public void setUp() {
        retryConfig = RetryConfig.DEFAULT
            .withNumRetries(Integer.MAX_VALUE)
            .withMaxDelay(0);

        var stockpileClient = new StockpileClientStub(commonPool());

        deletedMetricsDao = new InMemoryDeletedMetricsDao();

        metricsDaoFactory = new InMemoryMetricsDaoFactory();
        metricsDaoFactory.setSuspendShardInitOnCreate(true);

        shardResolver = new MetabaseShardResolverStub(
            SHARDS,
            metricsDaoFactory,
            stockpileClient);
    }

    @After
    public void tearDown() {
        close(metricsDaoFactory, shardResolver);
    }

    @Test
    public void alreadyCompleted() {
        // arrange
        // Every deleted-metrics DAO call fails: a completed task must finish without touching it.
        deletedMetricsDao.beforeSupplier = unavailable();

        var progress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(777)
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();

        var proc = rollbackDeletedMetrics(params(), progress);

        // act
        proc.start().join();

        // assert
        assertEquals(progress, proc.progress());
    }

    @Test
    public void shardIsNotLocalAnymore() {
        // arrange
        deletedMetricsDao.beforeSupplier = unavailable();

        // numId 666 is above MAX_NUM_ID, i.e. outside the shards given to the resolver.
        var params = params().toBuilder().setNumId(666).build();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        // act
        proc.start().join();

        // assert
        assertEquals(RollbackDeletedMetricsProgress.getDefaultInstance(), proc.progress());
    }

    @Test
    public void shardIsNotReady() {
        // arrange
        deletedMetricsDao.beforeSupplier = unavailable();

        // The shard is known but never made ready (init is suspended in setUp).
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        // act
        proc.start().join();

        // assert
        assertEquals(RollbackDeletedMetricsProgress.getDefaultInstance(), proc.progress());
    }

    @Test
    public void nothingToRollback() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        ensureShardReady(params.getNumId());

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(0)
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());
    }

    @Test
    public void rollbackOnSmallOperation() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        var alreadyExistingMetrics = Stream.generate(() -> metricInSpShard(666))
            .limit(random().nextInt(3, 5))
            .collect(toUnmodifiableList());

        ensureMetricsInDao(params.getNumId(), alreadyExistingMetrics);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(relevantMetrics.size())
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(irrelevantMetrics, deletedMetricsDao.metrics());

        var expectedExisting = Stream.concat(
                relevantMetrics.stream(),
                alreadyExistingMetrics.stream())
            .collect(toList());
        assertMetrics(expectedExisting, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void rollbackOnBigOperation() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        // More metrics than TEST_FIND_METRICS_LIMIT, so the rollback needs several batches.
        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(2000, 5000))
            .collect(toList());
        shuffle(relevantMetrics);

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());
        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(relevantMetrics.size())
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(irrelevantMetrics, deletedMetricsDao.metrics());
        assertMetrics(relevantMetrics, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void rollbackUpToQuota() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(10_000)
            .collect(toList());
        shuffle(relevantMetrics);

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);

        // Shard quota of 2500 metrics: only about a quarter of the 10k metrics can be restored.
        shardResolver.resolveShard(params.getNumId()).setMaxFileMetrics(2500);
        ensureShardReady(params.getNumId());

        // act
        var status = proc.start().thenApply(i -> Status.OK).exceptionally(Status::fromThrowable).join();

        // assert
        assertEquals(Status.Code.RESOURCE_EXHAUSTED, status.getCode());

        var progress = proc.progress();
        assertFalse(progress.getComplete());
        assertEquals(relevantMetrics.size(), progress.getTotalMetrics());
        assertThat(progress.getStillDeletedMetrics(), lessThan(relevantMetrics.size()));
        assertThat(progress.getProgress(), allOf(greaterThanOrEqualTo(0.2), lessThanOrEqualTo(0.3)));

        assertThat(
            getMetricsDao(params.getNumId()).metrics().size(),
            allOf(
                greaterThanOrEqualTo(2000),
                lessThanOrEqualTo(3000)));
    }

    @Test
    public void rollbackCollisionWithSameStockpileIds() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        // The first five relevant metrics already exist in the shard unchanged.
        var collisionMetrics = relevantMetrics.subList(0, 5);
        var otherRelevantMetrics = relevantMetrics.subList(5, relevantMetrics.size());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        var otherExistingMetrics = Stream.generate(() -> metricInSpShard(666))
            .limit(random().nextInt(3, 5));
        var alreadyExistingMetrics = Stream.concat(
                collisionMetrics.stream(),
                otherExistingMetrics)
            .collect(toList());

        ensureMetricsInDao(params.getNumId(), alreadyExistingMetrics);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(relevantMetrics.size())
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(irrelevantMetrics, deletedMetricsDao.metrics());

        var expectedExisting = Stream.concat(
                otherRelevantMetrics.stream(),
                alreadyExistingMetrics.stream())
            .collect(toList());
        assertMetrics(expectedExisting, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void rollbackCollisionWithDiffStockpileIds() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var collisionRelevantMetrics = relevantMetrics.subList(0, 5);
        var otherRelevantMetrics = relevantMetrics.subList(5, relevantMetrics.size());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        var otherExistingMetrics = Stream.generate(() -> metricInSpShard(666))
            .limit(random().nextInt(3, 5));
        // Same labels as the first two colliding metrics, but with a different stockpile shard id.
        var collisionExistingMetrics02 = collisionRelevantMetrics.subList(0, 2).stream()
            .<CoremonMetric>map(
                m -> new FileCoremonMetric(
                    m.getShardId() + 1000,
                    m.getLocalId(),
                    m.getLabels(),
                    m.getCreatedAtSeconds(),
                    m.getType()));
        // Same labels as the remaining three colliding metrics, but with a different local id.
        var collisionExistingMetrics25 = collisionRelevantMetrics.subList(2, 5).stream()
            .<CoremonMetric>map(
                m -> new FileCoremonMetric(
                    m.getShardId(),
                    m.getLocalId() + 1000,
                    m.getLabels(),
                    m.getCreatedAtSeconds(),
                    m.getType()));

        var alreadyExistingMetrics = Stream.of(
                collisionExistingMetrics02,
                collisionExistingMetrics25,
                otherExistingMetrics)
            .flatMap(identity())
            .collect(toList());

        ensureMetricsInDao(params.getNumId(), alreadyExistingMetrics);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // act
        var status = proc.start().thenApply(i -> Status.OK).exceptionally(Status::fromThrowable).join();

        // assert
        assertEquals(Status.Code.ALREADY_EXISTS, status.getCode());

        var progress = proc.progress();
        assertFalse(progress.getComplete());
        assertEquals(relevantMetrics.size(), progress.getTotalMetrics());
        // JUnit's contract is assertEquals(expected, actual).
        assertEquals(collisionRelevantMetrics.size(), progress.getStillDeletedMetrics());
        assertThat(progress.getProgress(), greaterThan(0.0));

        var expectedDeleted = Stream.concat(
                collisionRelevantMetrics.stream(),
                irrelevantMetrics.stream())
            .collect(toList());
        assertMetrics(expectedDeleted, deletedMetricsDao.metrics());

        var expectedExisting = Stream.concat(
                otherRelevantMetrics.stream(),
                alreadyExistingMetrics.stream())
            .collect(toList());
        assertMetrics(expectedExisting, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void resumeRollback() {
        // arrange
        var params = params();

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // Pretend a previous run out of 100 metrics already restored all but the remaining ones.
        var progress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(false)
            .setTotalMetrics(100)
            .setStillDeletedMetrics(relevantMetrics.size())
            .setProgress((100 - relevantMetrics.size()) / 100.0)
            .build();
        var proc = rollbackDeletedMetrics(params, progress);

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(100)
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(irrelevantMetrics, deletedMetricsDao.metrics());
        assertMetrics(relevantMetrics, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void resumeDeletionWhenProgressInaccurate() {
        // arrange
        var params = params();

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // Saved progress claims only one metric remains, while the DAO actually holds more.
        var progress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(false)
            .setTotalMetrics(100)
            .setStillDeletedMetrics(1)
            .setProgress(0.01)
            .build();
        var proc = rollbackDeletedMetrics(params, progress);

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(100)
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(irrelevantMetrics, deletedMetricsDao.metrics());
        assertMetrics(relevantMetrics, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void onErrorSaveLatestProgress() throws Exception {
        // arrange
        retryConfig = retryConfig.withNumRetries(3);

        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(1000, 2000))
            .collect(toUnmodifiableList());

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(500, 1000))
            .collect(toUnmodifiableList());

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        // The first metrics DAO call succeeds; every later call blocks until atLeastOnce is
        // completed and then fails with ABORTED.
        var lucky = new AtomicBoolean(true);
        var atLeastOnce = new CompletableFuture<>();
        getMetricsDao(params.getNumId()).beforeSupplier = () -> {
            if (lucky.getAndSet(false)) {
                return completedFuture(null);
            }

            return atLeastOnce.thenRun(() -> {
                throw Status.ABORTED.asRuntimeException();
            });
        };

        // act
        var future = proc.start();

        // Wait until some progress has been recorded. Also stop if the process already finished,
        // so an early failure surfaces through the assertions below instead of the global timeout.
        while (proc.progress().getProgress() == 0 && !future.isDone()) {
            TimeUnit.MILLISECONDS.sleep(1);
        }
        atLeastOnce.completeAsync(() -> null);

        var status = future.thenApply(i -> Status.OK).exceptionally(Status::fromThrowable).join();

        // assert
        assertNotEquals(Status.Code.OK, status.getCode());

        var progress = proc.progress();
        assertFalse(progress.getComplete());
        assertEquals(relevantMetrics.size(), progress.getTotalMetrics());
        assertThat(progress.getStillDeletedMetrics(), lessThan(relevantMetrics.size()));
        assertThat(progress.getProgress(), allOf(greaterThan(0.0), lessThan(1.0)));
    }

    @Test
    public void retryOnDownstreamCallNotOkStatusCodes() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var metrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(1000)
            .collect(toList());
        shuffle(metrics);

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), metrics);
        ensureShardReady(params.getNumId());

        // The first deleted-metrics DAO call always fails with UNAVAILABLE; later calls fail
        // randomly about half of the time.
        var a = new AtomicBoolean();
        deletedMetricsDao.beforeSupplier =
            () -> a.getAndSet(true) && random().nextBoolean()
                ? completedFuture(null)
                : unavailable().get();

        // act
        proc.start().join();

        // assert
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(true)
            .setTotalMetrics(metrics.size())
            .setStillDeletedMetrics(0)
            .setProgress(1)
            .build();
        assertEquals(expectedProgress, proc.progress());

        assertMetrics(List.of(), deletedMetricsDao.metrics());
        assertMetrics(metrics, getMetricsDao(params.getNumId()).metrics());
    }

    @Test
    public void canceledOnClose() {
        // arrange
        var params = params();
        var proc = rollbackDeletedMetrics(params, RollbackDeletedMetricsProgress.getDefaultInstance());

        var metrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(500)
            .collect(toList());
        shuffle(metrics);

        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), metrics);
        ensureShardReady(params.getNumId());

        // Close the process from inside its second deleted-metrics DAO call.
        var calls = new AtomicInteger(2);
        deletedMetricsDao.beforeSupplier = () -> {
            if (calls.decrementAndGet() == 0) {
                proc.close();
            }

            return completedFuture(null);
        };

        // act
        var status = proc.start().thenApply(i -> Status.OK).exceptionally(Status::fromThrowable).join();

        // assert
        assertEquals(Status.Code.CANCELLED, status.getCode());
        assertFalse(proc.progress().getComplete());
    }

    @Test
    public void gracefulInterruption() {
        // arrange
        var params = params();

        var relevantMetrics = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(3000, 4000))
            .collect(toUnmodifiableList());

        // Five relevant metrics that are already present in the shard's metrics DAO.
        var lo = random().nextInt(2500);
        var hi = lo + 5;
        var actuallyWrittenMetrics = relevantMetrics.subList(lo, hi);

        var irrelevantMetricsDiffOp = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());
        var irrelevantMetricsDiffNumId = Stream.generate(DeleteMetricsRandom::metric)
            .limit(random().nextInt(10, 20))
            .collect(toUnmodifiableList());

        var irrelevantMetrics = Streams.concat(
                irrelevantMetricsDiffOp.stream(),
                irrelevantMetricsDiffNumId.stream())
            .collect(toUnmodifiableList());

        var otherExistingMetrics = Stream.generate(() -> metricInSpShard(666))
            .limit(random().nextInt(3, 5));
        var alreadyExistingMetrics = Stream.concat(
                actuallyWrittenMetrics.stream(),
                otherExistingMetrics)
            .collect(toList());

        ensureMetricsInDao(params.getNumId(), alreadyExistingMetrics);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId(), relevantMetrics);
        ensureDeletedMetricsInDao(params.getOperationId() + "LOL", params.getNumId(), irrelevantMetricsDiffOp);
        ensureDeletedMetricsInDao(params.getOperationId(), params.getNumId() + 1000, irrelevantMetricsDiffNumId);
        ensureShardReady(params.getNumId());

        var progress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(false)
            .setTotalMetrics(relevantMetrics.size())
            .setStillDeletedMetrics(relevantMetrics.size())
            .setProgress(0)
            .build();
        // Interrupted mode: per the assertions below, only the already-written metrics are
        // removed from the deleted DAO and nothing new is written to the shard.
        var proc = rollbackDeletedMetrics(params, progress, true);

        // act
        proc.start().join();

        // assert
        var actualProgress = proc.progress();
        var expectedProgress = RollbackDeletedMetricsProgress.newBuilder()
            .setComplete(false)
            .setTotalMetrics(relevantMetrics.size())
            .setStillDeletedMetrics(relevantMetrics.size() - 5)
            .build();
        assertEquals(expectedProgress, actualProgress.toBuilder().clearProgress().build());
        assertThat(actualProgress.getProgress(), greaterThan(0.0));

        var expectedDeleted = Stream.of(
                relevantMetrics.subList(0, lo),
                relevantMetrics.subList(hi, relevantMetrics.size()),
                irrelevantMetrics)
            .flatMap(Collection::stream)
            .collect(toList());
        assertMetrics(expectedDeleted, deletedMetricsDao.metrics());

        assertMetrics(alreadyExistingMetrics, getMetricsDao(params.getNumId()).metrics());
    }

    /** Stores {@code metrics} in the deleted-metrics DAO under the given operation and shard. */
    private void ensureDeletedMetricsInDao(String operationId, int numId, Collection<CoremonMetric> metrics) {
        deletedMetricsDao.putAll(operationId, numId, metrics);
    }

    /** Adds {@code metrics} to the metrics DAO of shard {@code numId}. */
    private void ensureMetricsInDao(int numId, List<CoremonMetric> metrics) {
        getMetricsDao(numId).add(metrics);
    }

    /** Returns the in-memory metrics DAO for shard {@code numId} from the shared factory. */
    private InMemoryMetricsDao getMetricsDao(int numId) {
        return metricsDaoFactory.create(numId, Labels.allocator);
    }

    /** Resumes the suspended initialization of shard {@code numId} and waits until it is ready. */
    private void ensureShardReady(int numId) {
        metricsDaoFactory.resumeShardInit(numId);
        shardResolver.resolveShard(numId).awaitReady();
    }

    /** Creates the process under test in non-interrupted mode. */
    private RollbackDeletedMetrics rollbackDeletedMetrics(
        DeleteMetricsParams params,
        RollbackDeletedMetricsProgress progress)
    {
        return rollbackDeletedMetrics(params, progress, false);
    }

    /**
     * Creates the process under test wired to this test's fakes and current {@link #retryConfig}.
     *
     * @param params      operation parameters
     * @param progress    progress to resume from
     * @param interrupted whether the process is created in interrupted mode
     */
    private RollbackDeletedMetrics rollbackDeletedMetrics(
        DeleteMetricsParams params,
        RollbackDeletedMetricsProgress progress,
        boolean interrupted)
    {
        return new RollbackDeletedMetrics(
            retryConfig,
            deletedMetricsDao,
            shardResolver,
            commonPool(),
            params,
            progress,
            interrupted,
            TEST_FIND_METRICS_LIMIT);
    }

    /** Random operation parameters targeting one of the test shards. */
    private static DeleteMetricsParams params() {
        return DeleteMetricsRandom.params(MAX_NUM_ID);
    }

    private static ThreadLocalRandom random() {
        return ThreadLocalRandom.current();
    }

    /** Supplier whose every call yields a future already failed with {@code UNAVAILABLE}. */
    private static Supplier<CompletableFuture<?>> unavailable() {
        return () -> failedFuture(Status.UNAVAILABLE.asRuntimeException());
    }

}
