package ru.yandex.solomon.core.conf.flags;

import java.nio.file.Path;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;

import ru.yandex.solomon.core.db.dao.EntityFlagsDao;
import ru.yandex.solomon.core.db.dao.EntityFlagsDao.Record;
import ru.yandex.solomon.core.db.dao.ProjectFlagsDao;
import ru.yandex.solomon.core.db.dao.memory.InMemoryClusterFlagsDao;
import ru.yandex.solomon.core.db.dao.memory.InMemoryProjectFlagsDao;
import ru.yandex.solomon.core.db.dao.memory.InMemoryServiceFlagsDao;
import ru.yandex.solomon.core.db.dao.memory.InMemoryShardFlagsDao;
import ru.yandex.solomon.flags.FeatureFlag;
import ru.yandex.solomon.flags.FeatureFlags;
import ru.yandex.solomon.flags.FeatureFlagsConfig;
import ru.yandex.solomon.flags.FeatureFlagsMatchers;
import ru.yandex.solomon.ut.ManualClock;
import ru.yandex.solomon.ut.ManualScheduledExecutorService;
import ru.yandex.solomon.util.file.SimpleFileStorage;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;

/**
 * Tests for {@link FeatureFlagsWatcher}.
 *
 * <p>The watcher is wired to in-memory project/cluster/service/shard flag DAOs and a manual
 * scheduler driven by {@link ManualClock}. Each published snapshot is captured in {@link #flags}
 * and signalled through {@link #sync}. The tests check that snapshots reflect the DAO contents.
 * They also check that the previously known flags are kept while a DAO reports a failure.
 *
 * @author Vladimir Gordiychuk
 */
public class FeatureFlagsWatcherTest {
    @Rule
    public Timeout globalTimeout = Timeout.builder()
            .withTimeout(3, TimeUnit.SECONDS)
            .withLookingForStuckThread(true)
            .build();

    @Rule
    public TemporaryFolder tmp = new TemporaryFolder();
    @Rule
    public TestName testName = new TestName();

    /** Latest snapshot delivered to the watcher callback; written from the callback thread. */
    private volatile FeatureFlagsConfig flags;
    /** Counted down by the watcher callback each time a snapshot is published. */
    private volatile CountDownLatch sync;

    private ManualClock clock;
    private Path storage;
    private ManualScheduledExecutorService timer;
    private InMemoryProjectFlagsDao projectsDao;
    private InMemoryClusterFlagsDao clustersDao;
    private InMemoryServiceFlagsDao serviceDao;
    private InMemoryShardFlagsDao shardsDao;
    private FeatureFlagsWatcher watcher;

    @Before
    public void setUp() throws Exception {
        clock = new ManualClock();
        storage = tmp.newFolder(testName.getMethodName()).toPath();
        timer = new ManualScheduledExecutorService(1, clock);
        sync = new CountDownLatch(1);
        projectsDao = new InMemoryProjectFlagsDao();
        clustersDao = new InMemoryClusterFlagsDao();
        serviceDao = new InMemoryServiceFlagsDao();
        shardsDao = new InMemoryShardFlagsDao();
        watcher = new FeatureFlagsWatcher(
            new SimpleFileStorage(storage),
            timer,
            projectsDao,
            clustersDao,
            serviceDao,
            shardsDao,
            snapshot -> {
                flags = snapshot;
                sync.countDown();
            });
    }

    @After
    public void tearDown() throws Exception {
        // Close the watcher even if shutting down the timer fails, so resources are not leaked
        // into subsequent tests.
        try {
            timer.shutdownNow();
        } finally {
            watcher.close();
        }
    }

    @Test
    public void emptyState() throws InterruptedException {
        awaitReload();
        var v1 = flags;
        assertNotNull(v1);

        awaitReload();
        var v2 = flags;
        assertNotSame(v1, v2);

        awaitReload();
        var v3 = flags;
        assertNotSame(v2, v3);
    }

    @Test
    public void perShardFlag() throws InterruptedException {
        shardsDao.upsert(new EntityFlagsDao.Record(FeatureFlag.TEST.name(), "projectId", "shardId", true)).join();

        awaitReload();

        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId2"));

        shardsDao.deleteOne(FeatureFlag.TEST.name(), "projectId", "shardId").join();
        awaitReload();

        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId"));
    }

    @Test
    public void perProjectFlag() throws InterruptedException {
        projectsDao.upsert(new ProjectFlagsDao.Record(FeatureFlag.TEST.name(), "vlgo", true)).join();

        awaitReload();
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo").shardId("shard1"));
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo").shardId("shard2"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo2").shardId("shard3"));

        projectsDao.deleteOne(FeatureFlag.TEST.name(), "vlgo").join();
        awaitReload();
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo").shardId("shard1"));
    }

    @Test
    public void perAllProjectFlag() throws InterruptedException {
        awaitReload();
        projectsDao.upsert(new ProjectFlagsDao.Record(FeatureFlag.TEST.name(), "", true)).join();

        awaitReload();
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo").shardId("shard1"));
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shard1"));

        projectsDao.upsert(new ProjectFlagsDao.Record(FeatureFlag.TEST.name(), "", false)).join();
        awaitReload();

        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("vlgo").shardId("shard1"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shard1"));
    }

    @Test
    public void perServiceProviderFlag() throws InterruptedException {
        awaitReload();
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("alice").serviceProvider("myService"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("bob").serviceProvider("myService"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("eva").serviceProvider("anotherService"));

        // Join like every other DAO mutation in this class, so the write has completed
        // before the reload we wait for.
        serviceDao.upsert(new Record(FeatureFlag.TEST.name(), "", "myService", true)).join();
        awaitReload();

        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("alice").serviceProvider("myService"));
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("bob").serviceProvider("myService"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("eva").serviceProvider("anotherService"));

        serviceDao.deleteOne(FeatureFlag.TEST.name(), "", "myService").join();
        awaitReload();
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("alice").serviceProvider("myService"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("bob").serviceProvider("myService"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("eva").serviceProvider("anotherService"));
    }

    @Test
    public void ignoreInvalidFlags() throws InterruptedException {
        awaitReload();
        shardsDao.upsert(new EntityFlagsDao.Record("noise", "projectId", "shardId", true)).join();
        shardsDao.upsert(new EntityFlagsDao.Record(FeatureFlag.TEST.name(), "projectId", "shardId", true)).join();

        awaitReload();
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId2"));

        shardsDao.deleteOne(FeatureFlag.TEST.name(), "projectId", "shardId").join();
        awaitReload();
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("projectId").shardId("shardId"));
    }

    @Test
    public void fallbackProjectFlags() throws InterruptedException {
        awaitReload();

        projectsDao.upsert(new ProjectFlagsDao.Record(FeatureFlag.TEST.name(), "", true)).join();
        projectsDao.upsert(new ProjectFlagsDao.Record(FeatureFlag.TEST.name(), "solomon", false)).join();

        awaitReload();
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("test"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("solomon"));

        var v1 = flags;
        projectsDao.deleteOne(FeatureFlag.TEST.name(), "solomon").join();
        projectsDao.setFailure(new RuntimeException("Hi"));

        awaitReload();
        var v2 = flags;
        assertNotSame(v1, v2);

        // we lost latest delete because dao unavailable
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("test"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("solomon"));
        projectsDao.setFailure(null);

        awaitReload();
        // state sync with dao now
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("solomon"));
    }

    @Test
    public void fallbackShardFlags() throws InterruptedException {
        awaitReload();

        shardsDao.upsert(new EntityFlagsDao.Record(FeatureFlag.TEST.name(), "solomon", "test", true)).join();
        shardsDao.upsert(new EntityFlagsDao.Record(FeatureFlag.TEST.name(), "solomon", "test2", false)).join();

        awaitReload();
        assertHasFlag(FeatureFlag.TEST, new Resolve().shardId("test"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().shardId("test2"));

        var v1 = flags;
        shardsDao.deleteOne(FeatureFlag.TEST.name(), "solomon", "test").join();
        shardsDao.setFailure(new RuntimeException("Hi"));

        awaitReload();
        var v2 = flags;
        assertNotSame(v1, v2);

        // we lost latest delete because dao unavailable
        assertHasFlag(FeatureFlag.TEST, new Resolve().projectId("solomon").shardId("test"));
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("solomon").shardId("test2"));
        shardsDao.setFailure(null);

        awaitReload();
        // state sync with dao now
        assertHasNotFlag(FeatureFlag.TEST, new Resolve().projectId("solomon").shardId("test"));
    }

    /**
     * Asserts that {@code flag} is set once the latest snapshot is combined for {@code resolve}.
     */
    private void assertHasFlag(FeatureFlag flag, Resolve resolve) {
        FeatureFlagsMatchers.assertHasFlag(resolveFlags(resolve), flag);
    }

    /**
     * Asserts that {@code flag} is not set once the latest snapshot is combined for {@code resolve}.
     */
    private void assertHasNotFlag(FeatureFlag flag, Resolve resolve) {
        FeatureFlagsMatchers.assertHasNotFlag(resolveFlags(resolve), flag);
    }

    /**
     * Combines the flags of the latest captured snapshot for the entity described by {@code resolve}.
     * Requires that at least one snapshot has been published (see {@link #awaitReload()}).
     */
    private FeatureFlags resolveFlags(Resolve resolve) {
        FeatureFlags result = new FeatureFlags();
        flags.combineFlags(result, resolve.projectId, resolve.shardId, resolve.clusterId, resolve.serviceId, resolve.serviceProvider);
        return result;
    }

    /**
     * Blocks until the watcher has published two more snapshots.
     *
     * <p>While waiting, the manual clock is advanced by one minute every 5 ms so that tasks
     * scheduled on {@link #timer} become due. Waiting for two snapshots rather than one presumably
     * ensures that at least one reload started after the caller's DAO changes, since the first
     * may already have been in flight. TODO confirm against the watcher's scheduling.
     */
    private void awaitReload() throws InterruptedException {
        for (int index = 0; index < 2; index++) {
            var copy = new CountDownLatch(1);
            sync = copy;
            while (!copy.await(5, TimeUnit.MILLISECONDS)) {
                clock.passedTime(1, TimeUnit.MINUTES);
            }
        }
    }

    /**
     * Fluent description of the entity for which flags are resolved; unset coordinates keep
     * the defaults below.
     */
    private static class Resolve {
        String projectId = "solomon";
        String shardId = "shardId";
        String clusterId = "clusterId";
        String serviceId = "serviceId";
        String serviceProvider = "";

        public Resolve projectId(String projectId) {
            this.projectId = projectId;
            return this;
        }

        public Resolve shardId(String shardId) {
            this.shardId = shardId;
            return this;
        }

        public Resolve clusterId(String clusterId) {
            this.clusterId = clusterId;
            return this;
        }

        public Resolve serviceId(String serviceId) {
            this.serviceId = serviceId;
            return this;
        }

        public Resolve serviceProvider(String serviceProvider) {
            this.serviceProvider = serviceProvider;
            return this;
        }
    }
}
