package ru.yandex.solomon.tool.migration.kv;

import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Throwables;
import com.yandex.ydb.table.TableClient;
import io.netty.channel.EventLoopGroup;

import ru.yandex.kikimr.client.KikimrAnyResponseException;
import ru.yandex.kikimr.client.KikimrGrpcTransport;
import ru.yandex.kikimr.client.KikimrTransport;
import ru.yandex.kikimr.proto.FlatTxScheme;
import ru.yandex.kikimr.proto.Msgbus;
import ru.yandex.misc.cmdline.CmdArgsChief;
import ru.yandex.misc.cmdline.CmdLineArgs;
import ru.yandex.misc.cmdline.Parameter;
import ru.yandex.misc.thread.factory.DaemonThreadFactory;
import ru.yandex.misc.thread.factory.ThreadNameThreadFactory;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.core.db.dao.ShardsDao;
import ru.yandex.solomon.core.db.dao.ydb.YdbShardsDao;
import ru.yandex.solomon.core.db.model.Shard;
import ru.yandex.solomon.coremon.meta.CoremonMetric;
import ru.yandex.solomon.coremon.meta.CoremonMetricArray;
import ru.yandex.solomon.coremon.meta.FileCoremonMetric;
import ru.yandex.solomon.coremon.meta.db.MetricsDao;
import ru.yandex.solomon.coremon.meta.db.MetricsDaoFactory;
import ru.yandex.solomon.coremon.meta.db.ydb.YdbMetricsDaoFactory;
import ru.yandex.solomon.labels.shard.ShardKey;
import ru.yandex.solomon.main.logger.LoggerConfigurationUtils;
import ru.yandex.solomon.selfmon.ng.JvmMon;
import ru.yandex.solomon.selfmon.ng.ProcSelfMon;
import ru.yandex.solomon.tool.YdbHelper;
import ru.yandex.solomon.tool.cfg.SolomonCluster;
import ru.yandex.solomon.tool.migration.kv.MetabaseTransferMetrics.ShardMetrics;
import ru.yandex.solomon.util.ExceptionUtils;
import ru.yandex.solomon.util.NettyUtils;
import ru.yandex.solomon.util.PropertyInitializer;
import ru.yandex.solomon.util.actors.AsyncActorBody;
import ru.yandex.solomon.util.actors.AsyncActorRunner;
import ru.yandex.solomon.util.future.RetryCompletableFuture;
import ru.yandex.solomon.util.future.RetryConfig;
import ru.yandex.solomon.util.host.HostUtils;
import ru.yandex.solomon.util.time.DurationUtils;

/**
 * Command-line tool that copies coremon metric metadata of the selected Solomon shards from the
 * metabase of one cluster ({@code --source}) into the metabase of another ({@code --target}).
 *
 * <p>The shard list is read from the PROD_FRONT configuration database. Up to
 * {@link #MAX_SHARDS_IN_FLIGHT} shards are migrated concurrently; a shard whose migration fails is
 * retried from scratch according to {@link #retryConfig}. Progress is exported through
 * {@link MetabaseTransferMetrics}, registered in a registry that is pushed by
 * {@code MetricsPushClient}.
 *
 * @author Maksim Leonov (nohttp@)
 */
@ParametersAreNonnullByDefault
public class MetabaseTransfer {
    /** Maximum number of shards migrated concurrently. */
    private static final int MAX_SHARDS_IN_FLIGHT = 10;
    /** Maximum number of concurrently running write batches per shard. */
    private static final int MAX_WRITE_IN_FLIGHT = 10;
    /** Maximum number of metrics sent to the target in a single replace request. */
    private static final int MAX_WRITE_SENSORS_BATCH = 2000;

    private final CmdlineArgs args;
    private final ExecutorService executor;
    private final MetricsDaoFactory targetDao;
    private final MetricsDaoFactory sourceDao;
    private final ShardsDao shardsDao;
    private final MetabaseTransferMetrics metrics;

    private final RetryConfig retryConfig;

    /**
     * Wires up self-monitoring, the worker pool and the DAOs for a transfer between the clusters
     * named in {@code args}.
     *
     * @throws IllegalArgumentException if {@code --source} or {@code --target} was not given
     * @throws RuntimeException if the target is one of the production fetcher clusters
     */
    public MetabaseTransfer(CmdlineArgs args) {
        this.args = args;
        if (args.sourceClusterId == null) {
            throw new IllegalArgumentException("--source is required");
        }
        if (args.targetClusterId == null) {
            throw new IllegalArgumentException("--target is required");
        }
        // Never write into the production fetcher clusters.
        if (EnumSet.of(SolomonCluster.PROD_FETCHER_VLA, SolomonCluster.PROD_FETCHER_SAS).contains(args.targetClusterId)) {
            throw new RuntimeException("Dangerous");
        }

        MetricRegistry registry = new MetricRegistry(Labels.of("host", HostUtils.getShortName()));
        metrics = new MetabaseTransferMetrics(registry
                .subRegistry("from", this.args.sourceClusterId.name())
                .subRegistry("to", this.args.targetClusterId.name()));
        JvmMon.addAllMetrics(registry);
        ProcSelfMon.addCpuTimeMetrics(registry);
        ProcSelfMon.addMemoryMetrics(registry);
        ProcSelfMon.addThreadsMetrics(registry);
        MetricsPushClient.create().schedulePush(
                new ShardKey("solomon", "push", "push"),
                registry);
        executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2,
                new DaemonThreadFactory(new ThreadNameThreadFactory(MetabaseTransfer.class)));
        targetDao = createMetricFlatDao(args.targetClusterId, registry);
        sourceDao = createMetricFlatDao(args.sourceClusterId, registry);

        shardsDao = createShardsDao();
        retryConfig = RetryConfig.DEFAULT
                .withNumRetries(30)
                .withDelay(1_000)
                .withStats((timeSpentMillis, cause) -> {
                    System.out.println("Failed, retrying...\n" + Throwables.getStackTraceAsString(cause));
                    metrics.retryCount.inc();
                });
    }

    /**
     * Entry point: parses the command line, runs the transfer and exits.
     *
     * <p>NOTE(review): the 30 second sleep before {@code System.exit} presumably gives the metrics
     * push a chance to send final values — confirm.
     */
    public static void main(String[] args) {
        try {
            PropertyInitializer.init();
            LoggerConfigurationUtils.disableLogger();

            CmdlineArgs parsed = CmdArgsChief.parse(CmdlineArgs.class, args);
            new MetabaseTransfer(parsed).run();
            TimeUnit.SECONDS.sleep(30L);
            System.exit(0);
        } catch (Throwable e) {
            ExceptionUtils.uncaughtException(e);
        }
    }

    /**
     * Selects shards according to {@code args.shards} — {@code all}, a single shard id or a list of
     * shard ids ({@code local} is not implemented yet) — and migrates them.
     */
    private void run() {
        System.out.println("fqdn: " + HostUtils.getFqdn());
        // The positional parameter has no default and may be left null when no ids are given.
        if (args.shards == null || args.shards.isEmpty()) {
            System.out.println("No shards to process!");
            return;
        }

        Predicate<Shard> shardFilter;
        if (args.shards.size() == 1) {
            String single = args.shards.get(0);
            if (single.equals("all")) {
                shardFilter = shard -> true;
            } else if (single.equals("local")) {
                throw new UnsupportedOperationException("TODO: implement shard filtration by hostname");
            } else {
                shardFilter = shard -> shard.getId().equals(single);
            }
        } else {
            Set<String> shardIds = new HashSet<>(args.shards);
            shardFilter = shard -> shardIds.contains(shard.getId());
        }

        // Thread-safe queue: it is polled from an actor body that may have several calls in flight.
        ConcurrentLinkedQueue<Shard> shards = shardsDao.findAll().join()
            .stream()
            .filter(shardFilter)
            .collect(Collectors.toCollection(ConcurrentLinkedQueue::new));
        migrateShards(shards);
    }

    /**
     * Migrates every shard in {@code shards}, at most {@link #MAX_SHARDS_IN_FLIGHT} at a time, and
     * blocks until all of them are done. Each shard is retried as a whole on failure.
     */
    private void migrateShards(ConcurrentLinkedQueue<Shard> shards) {
        int size = shards.size();
        System.out.println(size + " shards to process");
        metrics.shardsCount.set(size);

        AsyncActorBody body = () -> {
            Shard shard = shards.poll();
            if (shard == null) {
                return CompletableFuture.completedFuture(AsyncActorBody.DONE_MARKER);
            }

            System.out.println("Processed " + this.metrics.shardsProcessed.get() + "/" + size + " shards");
            System.out.println("Migrating " + shard.getId() + " => " + args.targetClusterId);
            // Every attempt builds a fresh Task, so a retry starts from an empty queue.
            return RetryCompletableFuture.runWithRetries(
                    () -> migrateShard(shard, args.maxAgeDays),
                    retryConfig)
                    .thenAccept(ignore -> {
                        metrics.shardsProcessed.inc();
                    });
        };

        AsyncActorRunner actorRunner = new AsyncActorRunner(body, executor, MAX_SHARDS_IN_FLIGHT);
        actorRunner.start().join();
        System.out.println("Processed " + this.metrics.shardsProcessed.get() + "/" + size + " shards");
    }

    /** Creates a read-write YDB metrics DAO factory under {@code <cluster root>/Solomon/metrics}. */
    private MetricsDaoFactory createMetricFlatDao(SolomonCluster cluster, MetricRegistry registry) {
        TableClient tableClient = YdbHelper.createTableClient(cluster);
        return YdbMetricsDaoFactory.forReadWrite(
            tableClient,
            cluster.kikimrRootPath() + "/Solomon/metrics",
            registry);
    }

    /** Creates the shard configuration DAO; shard configs are always read from PROD_FRONT. */
    private ShardsDao createShardsDao() {
        SolomonCluster cluster = SolomonCluster.PROD_FRONT;
        TableClient tableClient = YdbHelper.createTableClient(cluster);
        return new YdbShardsDao(tableClient,  cluster.kikimrRootPath() + "/Solomon/Config/V2/Shard", new ObjectMapper(), ForkJoinPool.commonPool());
    }

    /** Runs a single migration attempt for {@code shard}. */
    private CompletableFuture<Void> migrateShard(Shard shard, int maxAgeDays) {
        return new Task(shard, maxAgeDays).run();
    }

    @CmdLineArgs
    public static class CmdlineArgs {
        @Parameter(names = "--source", description = "Source stockpile cluster id")
        public SolomonCluster sourceClusterId;
        @Parameter(names = "--target", description = "Target stockpile cluster id")
        public SolomonCluster targetClusterId;

        @Parameter(names = "--max-age", description = "Max metric age to transfer (in days). If zero, all metrics will be transferred")
        public Integer maxAgeDays = 0;

        @Parameter(description = "solomon shards to migrate [also supports 'all' and 'local']")
        public List<String> shards;
    }

    /**
     * One migration attempt for one shard.
     *
     * <p>A single loader streams metrics from the source DAO into {@link #queue}. Meanwhile up to
     * {@link #MAX_WRITE_IN_FLIGHT} writers drain the queue in batches of at most
     * {@link #MAX_WRITE_SENSORS_BATCH} and replace them in the target DAO. The loader sets
     * {@link #doneLoad} when it stops; writers finish once that flag is set and the queue is empty.
     * A task is single-use: a failed load fails the task, and the caller retries with a new one.
     */
    private class Task {
        private final String shardId;
        private final MetricsDao source;
        private final MetricsDao target;
        private final int maxAgeDays;
        private final ConcurrentLinkedQueue<CoremonMetric> queue = new ConcurrentLinkedQueue<>();
        private final CompletableFuture<Void> doneFuture = new CompletableFuture<>();
        private volatile boolean doneLoad = false;

        // Lets writers sleep while the queue is empty and the loader is still running.
        private final ReentrantLock lock = new ReentrantLock();
        private final Condition nonEmptyOrDone = lock.newCondition();
        private final ShardMetrics shardMetrics;

        public Task(Shard shard, int maxAgeDays) {
            this.shardId = shard.getId();
            this.maxAgeDays = maxAgeDays;
            source = sourceDao.create(shard.getNumId());
            target = targetDao.create(shard.getNumId());
            shardMetrics = metrics.getShardMetrics(shardId);
            // Keep the first start time across retries of the same shard.
            if (shardMetrics.startedAtNanos == 0) {
                shardMetrics.startedAtNanos = System.nanoTime();
            }
        }

        /**
         * Fetches the source metric count, then loads and writes concurrently.
         *
         * @return a future that completes when both loading and writing are finished, or
         *         completes exceptionally (wrapped with the shard id) if either failed
         */
        public CompletableFuture<Void> run() {
            System.out.println(msgPrefix() + " starting migrate");
            source.getMetricCount()
                    .thenAccept(count -> {
                        System.out.println(msgPrefix() + " estimated rows count " + count);
                        shardMetrics.metricsTotal.set(count);
                    })
                    .thenCompose(ignore -> CompletableFuture.allOf(load(), write()))
                    .whenComplete((ignore, e) -> {
                        if (e != null) {
                            System.out.println(msgPrefix() + " failed processing");
                            doneFuture.completeExceptionally(new RuntimeException("shardId: " + shardId, e));
                        } else {
                            System.out.println(msgPrefix() + " finish processing");
                            shardMetrics.completedAtNanos = System.nanoTime();
                            doneFuture.complete(null);
                        }
                    });

            return doneFuture;
        }

        /** Human-readable time elapsed since the shard's first migration attempt started. */
        private String timeSpend() {
            long spendNanos = System.nanoTime() - shardMetrics.startedAtNanos;
            long spendMillis = TimeUnit.NANOSECONDS.toMillis(spendNanos);
            if (spendMillis < 1000) {
                return spendMillis + "ms";
            }
            return DurationUtils.formatDurationMillis(spendMillis);
        }

        /**
         * Streams the shard's metrics from the source into {@link #queue} on {@link #executor}.
         *
         * <p>Deliberately not retried here. Once this method gives up, the writers drain the queue
         * and stop, so a retried load would fill a queue that nobody reads any more. A failure
         * instead fails the task, and the shard-level retry starts over with a fresh task.
         * A missing source table is treated as an empty shard.
         */
        private CompletableFuture<Void> load() {
            CompletableFuture<Void> future = new CompletableFuture<>();
            executor.execute(() -> {
                // Shard metrics are shared between attempts, so reset the per-attempt counters.
                shardMetrics.metricsLoad.reset();
                shardMetrics.metricsWrite.reset();
                System.out.println(msgPrefix() + " start loading");
                try {
                    // Metrics created before this instant are skipped; maxAgeDays == 0 keeps all.
                    final long minCreatedAtSeconds = maxAgeDays == 0
                            ? Long.MIN_VALUE
                            : Instant.now().minus(maxAgeDays, ChronoUnit.DAYS).getEpochSecond();
                    final long loaded = source.findMetrics(chunk -> {
                        shardMetrics.metricsLoad.inc();
                        for (int i = 0; i < chunk.size(); i++) {
                            CoremonMetric metric = chunk.get(i);
                            if (metric.getCreatedAtSeconds() >= minCreatedAtSeconds) {
                                queue.add(new FileCoremonMetric(metric));
                            } else {
                                // Skipped metrics still count as processed for progress reporting.
                                shardMetrics.metricsWrite.inc();
                            }
                        }
                        // Wake writers once per chunk; signalling only on an empty->non-empty
                        // transition can leave a writer asleep while metrics are queued.
                        signalToWriter();
                    }, OptionalLong.empty()).join();
                    shardMetrics.metricsTotal.set(loaded);
                    System.out.println(msgPrefix() + " complete loading");
                    future.complete(null);
                } catch (Throwable e) {
                    if (isPathDoesNotExist(e)) {
                        // shard doesn't have metrics at all
                        future.complete(null);
                    } else {
                        future.completeExceptionally(new RuntimeException(msgPrefix() + " load failed ", e));
                    }
                } finally {
                    // Terminal for this task: writers drain what is queued and then stop.
                    doneLoad = true;
                    signalToWriter();
                }
            });
            return future;
        }

        /** Log prefix: shard id and elapsed time. */
        private String msgPrefix() {
            return shardId + " (" + timeSpend() + "): ";
        }

        /**
         * Starts up to {@link #MAX_WRITE_IN_FLIGHT} writers that drain {@link #queue} into the target.
         *
         * @return a future that completes once loading is finished and the queue is empty
         */
        private CompletableFuture<Void> write() {
            System.out.println(msgPrefix() + " start writing");
            AsyncActorBody body = () -> {
                // Read the flag before draining: if it was already set, no more metrics can
                // arrive, so an empty batch really means the queue is exhausted.
                boolean loadFinished = this.doneLoad;
                CoremonMetricArray batch = pull();
                if (batch.isEmpty()) {
                    if (loadFinished) {
                        System.out.println(msgPrefix() + " complete writing");
                        batch.close();
                        return CompletableFuture.completedFuture(AsyncActorBody.DONE_MARKER);
                    }
                    batch.close();
                    return CompletableFuture.completedFuture(null);
                }

                int count = batch.size();
                return replaceWithRetry(batch)
                        .thenApply(unit -> {
                            shardMetrics.metricsWrite.add(count);
                            return unit;
                        });
            };

            AsyncActorRunner actorRunner = new AsyncActorRunner(body, executor, MAX_WRITE_IN_FLIGHT);
            return actorRunner.start();
        }

        /**
         * Replaces {@code batch} in the target, retrying according to {@link #retryConfig}.
         * The batch is closed exactly once, after the last attempt has completed, so retries never
         * see an already-closed array.
         */
        private CompletableFuture<Void> replaceWithRetry(CoremonMetricArray batch) {
            CompletableFuture<Void> replaced = RetryCompletableFuture.runWithRetries(
                    () -> target.replaceMetrics(batch).thenApply(result -> {
                        // Release the array returned by the DAO as well.
                        result.closeSilent();
                        return null;
                    }),
                    retryConfig);
            return replaced.whenComplete((ignore, e) -> batch.close());
        }

        /**
         * Takes up to {@link #MAX_WRITE_SENSORS_BATCH} metrics from the queue. If the queue is empty,
         * blocks until metrics arrive or loading finishes. The result may be empty only in the
         * latter case (or when another writer drained the queue first).
         */
        private CoremonMetricArray pull() {
            CoremonMetricArray result = new CoremonMetricArray(MAX_WRITE_SENSORS_BATCH);
            drainTo(result);

            if (!result.isEmpty()) {
                return result;
            }

            lock.lock();
            try {
                while (queue.isEmpty() && !doneLoad) {
                    nonEmptyOrDone.awaitUninterruptibly();
                }

                drainTo(result);
                return result;
            } finally {
                lock.unlock();
            }
        }

        /** Moves metrics from the queue into {@code batch} until it is full or the queue is empty. */
        private void drainTo(CoremonMetricArray batch) {
            while (batch.size() < MAX_WRITE_SENSORS_BATCH) {
                CoremonMetric metric = queue.poll();
                if (metric == null) {
                    return;
                }
                batch.add(metric);
            }
        }

        /** Wakes all writers waiting in {@link #pull()}. */
        private void signalToWriter() {
            lock.lock();
            try {
                nonEmptyOrDone.signalAll();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns true if the first {@link KikimrAnyResponseException} in the cause chain of
         * {@code t} reports that the requested path does not exist. It checks the
         * {@code StatusPathDoesNotExist} status first, then a {@code PathErrorUnknown} message.
         */
        private boolean isPathDoesNotExist(Throwable t) {
            for (Throwable cause = t; cause != null; cause = cause.getCause()) {
                if (!(cause instanceof KikimrAnyResponseException)) {
                    continue;
                }
                KikimrAnyResponseException anyResponseException = (KikimrAnyResponseException) cause;
                if (!(anyResponseException.getResponse() instanceof Msgbus.TResponse)) {
                    return false;
                }
                Msgbus.TResponse response = (Msgbus.TResponse) anyResponseException.getResponse();
                if (FlatTxScheme.EStatus.StatusPathDoesNotExist.getNumber() == response.getStatus()) {
                    return true;
                }
                String message = anyResponseException.getMessage();
                return message != null && message.contains("PathErrorUnknown");
            }
            return false;
        }
    }
}
