package ru.yandex.solomon.core.conf.flags;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.solomon.core.db.dao.ClusterFlagsDao;
import ru.yandex.solomon.core.db.dao.EntityFlagsDao;
import ru.yandex.solomon.core.db.dao.ProjectFlagsDao;
import ru.yandex.solomon.core.db.dao.ServiceFlagsDao;
import ru.yandex.solomon.core.db.dao.ShardFlagsDao;
import ru.yandex.solomon.flags.FeatureFlag;
import ru.yandex.solomon.flags.FeatureFlagsConfig;
import ru.yandex.solomon.flags.FeatureFlagsListener;
import ru.yandex.solomon.flags.FeatureFlagsProject;
import ru.yandex.solomon.staffOnly.manager.special.InstantMillis;
import ru.yandex.solomon.util.file.FileStorage;
import ru.yandex.solomon.util.io.IoFunction;

/**
 * Periodically reloads feature flags from the project, service, cluster and shard flag DAOs,
 * assembles them into a {@link FeatureFlagsConfig} and hands every successfully built snapshot
 * to a {@link FeatureFlagsListener}.
 *
 * <p>Each successfully obtained list of records is also persisted to a local state file via
 * {@link FileStorage}. When a DAO lookup fails, the last persisted state for that DAO is used
 * as a fallback. Reloads run on the supplied timer at a randomized interval; a failed reload is
 * retried sooner than a successful one.
 *
 * @author Vladimir Gordiychuk
 */
public class FeatureFlagsWatcher implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(FeatureFlagsWatcher.class);

    private static final long RELOAD_INTERVAL = TimeUnit.MINUTES.toMillis(1L);
    private static final String PROJECT_STATE = "project.flags.state";
    private static final String CLUSTER_STATE = "cluster.flags.state";
    private static final String SERVICE_STATE = "service.flags.state";
    private static final String SHARD_STATE = "shard.flags.state";

    // Project id under which project-level records describe default flags, and service-level
    // records describe service provider flags.
    private static final String DEFAULT_ID = "";

    private final FileStorage storage;
    private final ScheduledExecutorService timer;
    private final ProjectFlagsDao projectDao;
    private final ClusterFlagsDao clusterDao;
    private final ServiceFlagsDao serviceDao;
    private final ShardFlagsDao shardDao;
    private final FeatureFlagsListener consumer;

    // Most recently submitted reload task; cancelled by close().
    private volatile Future<?> scheduled;
    private volatile boolean closed;

    private volatile Throwable lastError;
    @InstantMillis
    private volatile long lastErrorInstant;

    /**
     * Creates the watcher and immediately submits the first reload to {@code timer}.
     *
     * @param storage    local storage used to persist and restore flag state files
     * @param timer      executor that runs reload iterations
     * @param projectDao source of project-level and default flags
     * @param clusterDao source of cluster-level flags
     * @param serviceDao source of service-level and service provider flags
     * @param shardDao   source of shard-level flags
     * @param consumer   receives every successfully assembled flags snapshot
     */
    public FeatureFlagsWatcher(
            FileStorage storage,
            ScheduledExecutorService timer,
            ProjectFlagsDao projectDao,
            ClusterFlagsDao clusterDao,
            ServiceFlagsDao serviceDao,
            ShardFlagsDao shardDao,
            FeatureFlagsListener consumer)
    {
        this.storage = storage;
        this.timer = timer;
        this.projectDao = projectDao;
        this.clusterDao = clusterDao;
        this.serviceDao = serviceDao;
        this.shardDao = shardDao;
        this.consumer = consumer;
        this.scheduled = timer.submit(this::runScheduled);
    }

    /**
     * Schedules the next reload after a random delay in {@code [minDelayMillis, maxDelayMillis)}.
     * Does nothing once the watcher is closed.
     */
    private void schedule(long minDelayMillis, long maxDelayMillis) {
        if (closed) {
            return;
        }

        // nextLong(origin, bound) rejects origin >= bound, so fall back to the lower bound.
        final long delayMillis = maxDelayMillis > minDelayMillis
                ? ThreadLocalRandom.current().nextLong(minDelayMillis, maxDelayMillis)
                : minDelayMillis;
        if (delayMillis <= 0) {
            // Remember the future in both branches so close() can cancel it.
            scheduled = timer.submit(this::runScheduled);
        } else {
            scheduled = timer.schedule(this::runScheduled, delayMillis, TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Runs one reload iteration. On success the listener is notified and the next reload is
     * scheduled within [RELOAD_INTERVAL, 2 * RELOAD_INTERVAL); on failure the error is recorded
     * and a retry is scheduled sooner.
     */
    private void runScheduled() {
        if (closed) {
            return;
        }

        try {
            var task = new ReloadTask();
            task.run().whenComplete((flags, e) -> {
                try {
                    if (e != null) {
                        failedIteration(e);
                    } else {
                        notifyConsumer(flags);
                        schedule(RELOAD_INTERVAL, RELOAD_INTERVAL * 2);
                    }
                } catch (Throwable t) {
                    // An exception escaping a whenComplete action only completes a future nobody
                    // observes, which would silently stop the reload loop. Make it visible.
                    logger.error("Schedule next feature flags reload failed", t);
                }
            });
        } catch (Throwable e) {
            failedIteration(e);
        }
    }

    /** Passes a snapshot to the listener; listener failures are logged and do not stop reloads. */
    private void notifyConsumer(FeatureFlagsConfig flags) {
        try {
            consumer.onConfigurationLoad(flags);
        } catch (Throwable e) {
            logger.error("Consume actual feature flags snapshot failed", e);
        }
    }

    /** Records the failure and schedules a retry within [RELOAD_INTERVAL / 2, RELOAD_INTERVAL). */
    private void failedIteration(Throwable e) {
        lastError = e;
        lastErrorInstant = System.currentTimeMillis();
        logger.error("Reload feature flags failed", e);
        schedule(RELOAD_INTERVAL / 2, RELOAD_INTERVAL);
    }

    /**
     * Stops future reloads and cancels the pending one (without interrupting one already running).
     */
    @Override
    public void close() {
        closed = true;
        var copy = scheduled;
        if (copy != null) {
            copy.cancel(false);
        }
    }

    /** Adds a single entity-level flag to the given project's flags. */
    @FunctionalInterface
    private interface FlagConsumer {
        void consume(FeatureFlagsProject project, String id, FeatureFlag flag, boolean value);
    }

    /** A single reload iteration that builds a fresh {@link FeatureFlagsConfig}. */
    private class ReloadTask {
        private final FeatureFlagsConfig configFlags;

        public ReloadTask() {
            this.configFlags = new FeatureFlagsConfig();
        }

        /**
         * Loads project, service, cluster and shard flags one after another and completes with
         * the assembled config. Projects are loaded first because later stages look up per-project
         * flags via {@code configFlags.getProjectFlags}.
         */
        public CompletableFuture<FeatureFlagsConfig> run() {
            return loadProjects()
                    .thenCompose(ignore -> loadServices())
                    .thenCompose(ignore -> loadEntries(clusterDao, FeatureFlagsProject::addClusterFlag, CLUSTER_STATE))
                    .thenCompose(ignore -> loadEntries(shardDao, FeatureFlagsProject::addShardFlag, SHARD_STATE))
                    .thenApply(ignore -> configFlags);
        }

        /**
         * Loads project-level records (falling back to the state file on DAO failure). Records
         * with {@link #DEFAULT_ID} become default flags; unknown flag names are skipped.
         */
        private CompletableFuture<Void> loadProjects() {
            return projectDao.findAll()
                    .exceptionally(e -> loadProjectsFlagsFromFile(PROJECT_STATE, e))
                    .thenAccept(records -> {
                        for (var record : records) {
                            var flag = FeatureFlag.byName(record.flag);
                            if (flag == null) {
                                continue;
                            }

                            if (DEFAULT_ID.equals(record.projectId)) {
                                configFlags.addDefaultFlag(flag, record.value);
                            } else {
                                configFlags.addProjectFlag(record.projectId, flag, record.value);
                            }
                        }
                        saveProjectsFlagsToFile(PROJECT_STATE, records);
                    });
        }

        /**
         * Loads service-level records (falling back to the state file on DAO failure). Records
         * with {@link #DEFAULT_ID} as project id become service provider flags.
         */
        private CompletableFuture<Void> loadServices() {
            return serviceDao.findAll()
                    .exceptionally(e -> loadEntityFlagsFromFile(SERVICE_STATE, e))
                    .thenAccept(records -> {
                        for (var record : records) {
                            var flag = FeatureFlag.byName(record.flag);
                            if (flag == null) {
                                continue;
                            }

                            if (DEFAULT_ID.equals(record.projectId)) {
                                configFlags.addServiceProviderFlag(record.id, flag, record.value);
                            } else {
                                var project = configFlags.getProjectFlags(record.projectId);
                                project.addServiceFlag(record.id, flag, record.value);
                            }
                        }
                        saveEntityFlagsToFile(SERVICE_STATE, records);
                    });
        }

        /**
         * Loads entity-level records from {@code dao} (falling back to {@code stateFile} on
         * failure) and hands each known flag to {@code flagConsumer} for its project.
         */
        private CompletableFuture<Void> loadEntries(EntityFlagsDao dao, FlagConsumer flagConsumer, String stateFile) {
            return dao.findAll()
                    .exceptionally(e -> loadEntityFlagsFromFile(stateFile, e))
                    .thenAccept(records -> {
                        for (var record : records) {
                            var flag = FeatureFlag.byName(record.flag);
                            if (flag == null) {
                                continue;
                            }
                            var project = configFlags.getProjectFlags(record.projectId);
                            flagConsumer.consume(project, record.id, flag, record.value);
                        }
                        saveEntityFlagsToFile(stateFile, records);
                    });
        }

        /** Persists project records as {@code flag;projectId;value} lines. */
        private void saveProjectsFlagsToFile(String fileName, List<ProjectFlagsDao.Record> records) {
            saveStateToFile(fileName, records, record -> {
                return record.flag + ";" + record.projectId + ";" + record.value;
            });
        }

        /** Reads project records written by {@link #saveProjectsFlagsToFile}; malformed lines map to null. */
        private List<ProjectFlagsDao.Record> loadProjectsFlagsFromFile(String fileName, Throwable cause) {
            return loadStateFromFile(fileName, cause, s -> {
                var parts = s.split(";");
                if (parts.length != 3) {
                    return null;
                }
                return new ProjectFlagsDao.Record(parts[0], parts[1], Boolean.parseBoolean(parts[2]));
            });
        }

        /** Persists entity records as {@code flag;projectId;id;value} lines. */
        private void saveEntityFlagsToFile(String fileName, List<EntityFlagsDao.Record> records) {
            saveStateToFile(fileName, records, record -> {
                return record.flag + ";" + record.projectId + ";" + record.id + ";" + record.value;
            });
        }

        /** Reads entity records written by {@link #saveEntityFlagsToFile}; malformed lines map to null. */
        private List<EntityFlagsDao.Record> loadEntityFlagsFromFile(String fileName, Throwable cause) {
            return loadStateFromFile(fileName, cause, s -> {
                var parts = s.split(";");
                if (parts.length != 4) {
                    return null;
                }
                return new EntityFlagsDao.Record(parts[0], parts[1], parts[2], Boolean.parseBoolean(parts[3]));
            });
        }

        /** Best-effort write of the state file; failures are logged and otherwise ignored. */
        private <T> void saveStateToFile(String fileName, List<T> records, IoFunction<T, String> serialize) {
            try {
                storage.saveValues(fileName, records, serialize);
            } catch (Throwable e) {
                logger.error("Error while writing state file " + fileName, e);
            }
        }

        /**
         * Fallback used when a DAO lookup fails: reads previously persisted records.
         *
         * @param cause the original DAO failure; it is rethrown (wrapped) if no fallback exists
         * @throws RuntimeException wrapping {@code cause} when the file cannot be read or
         *                          {@code storage.loadValues} returns null
         */
        private <T> List<T> loadStateFromFile(String fileName, Throwable cause, IoFunction<String, T> deserialize) {
            List<T> records;
            try {
                records = storage.loadValues(fileName, deserialize);
            } catch (Throwable e) {
                logger.error("Error while reading state file " + fileName, e);
                cause.addSuppressed(e);
                throw new RuntimeException(cause);
            }

            // Kept outside the try block: throwing inside it would be caught by the handler above,
            // logging a bogus read error and attaching the wrapper to its own cause as suppressed.
            if (records == null) {
                throw new RuntimeException(cause);
            }

            logger.error("Failed load flags, fallback to file {}", fileName, cause);
            return records;
        }
    }
}
