package ru.yandex.solomon.gateway.shardHealth;

import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Import;
import org.springframework.stereotype.Component;

import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monlib.metrics.meter.ExpMovingAverage;
import ru.yandex.monlib.metrics.meter.TickMixin;
import ru.yandex.salmon.fetcher.proto.FetcherApiProto.FetcherShardHealth;
import ru.yandex.salmon.fetcher.proto.FetcherApiProto.ShardsHealthResponse;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;
import ru.yandex.solomon.fetcher.client.FetcherClient;
import ru.yandex.solomon.fetcher.client.FetcherClientContext;


/**
 * @author Oleg Baryshnikov
 */
@Component
@Import({
    FetcherClientContext.class,
})
@ParametersAreNonnullByDefault
public class ShardHealthChecker {

    private static final Duration ITERATION_DELAY = Duration.ofMinutes(1);
    private static final Duration START_DELAY = Duration.ofMinutes(2);

    private static final Logger logger = LoggerFactory.getLogger(ShardHealthChecker.class);

    private final FetcherClient fetcherClient;
    private final ScheduledExecutorService executor;

    /** Moving-average meters per shard, keyed by the shard's numeric id. */
    private final ConcurrentHashMap<Integer, ShardHealthMeters> shardHealthById = new ConcurrentHashMap<>();

    /**
     * Creates the checker and schedules its first iteration on the provided scheduler.
     *
     * <p>The first iteration runs no earlier than {@link #START_DELAY} from now, on the next
     * wall-clock boundary that is a multiple of {@link #ITERATION_DELAY}.
     */
    @Autowired
    public ShardHealthChecker(FetcherClient fetcherClient, ThreadPoolProvider threadPoolProvider) {
        this.fetcherClient = fetcherClient;
        this.executor = threadPoolProvider.getSchedulerExecutorService();

        long earliestStartMillis = System.currentTimeMillis() + START_DELAY.toMillis();
        // Delay from "now" = the mandatory START_DELAY plus the time from that point to the next aligned boundary.
        long initialDelayMillis = START_DELAY.toMillis() + delayToNextAlignedIterationMillis(earliestStartMillis);

        executor.schedule(this::iteration, initialDelayMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Returns the smoothed failure percentage of a shard's fetch targets.
     *
     * @param shardId numeric shard id
     * @return value in [0, 100], where 0 means no failed fetches and 100 means no successful ones;
     *         0.0 if the shard is not currently tracked
     */
    public double getShardHealthById(int shardId) {
        ShardHealthMeters meters = shardHealthById.get(shardId);
        if (meters == null) {
            return 0.0;
        }
        return meters.getHealth();
    }

    /**
     * Runs one update round and, whatever its outcome, schedules the next round.
     */
    private void iteration() {
        updateShardHealths()
            .whenCompleteAsync((ignored, throwable) -> {
                if (throwable != null) {
                    logger.warn("failed to update shard healths: {}", throwable.getMessage());
                }
                scheduleNextIteration();
            });
    }

    /**
     * Schedules {@link #iteration()} at the next {@link #ITERATION_DELAY}-aligned wall-clock boundary.
     * A rejected submission (e.g. after scheduler shutdown) is logged instead of being silently lost
     * inside the completion stage.
     */
    private void scheduleNextIteration() {
        long delayMillis = delayToNextAlignedIterationMillis(System.currentTimeMillis());
        try {
            executor.schedule(this::iteration, delayMillis, TimeUnit.MILLISECONDS);
        } catch (RejectedExecutionException e) {
            logger.warn("cannot schedule next shard health iteration", e);
        }
    }

    /**
     * Returns the number of milliseconds from {@code timeMillis} to the next epoch-aligned multiple
     * of {@link #ITERATION_DELAY}. The result is in {@code (0, ITERATION_DELAY]}: an already-aligned
     * time waits a full period.
     */
    private static long delayToNextAlignedIterationMillis(long timeMillis) {
        long periodMillis = ITERATION_DELAY.toMillis();
        return periodMillis - Math.floorMod(timeMillis, periodMillis);
    }

    /**
     * Queries every host known to the fetcher client, merges the per-shard results and updates the meters.
     *
     * @return future completed once meters are updated; completed exceptionally if the round could not even start
     */
    private CompletableFuture<Void> updateShardHealths() {
        try {
            Set<String> hosts = fetcherClient.getKnownHosts();

            // Each future is completed with an empty response on failure, so one bad host does not fail the round.
            List<CompletableFuture<ShardsHealthResponse>> futures =
                hosts.stream()
                    .map(this::loadShardHealthsByHostOrEmpty)
                    .collect(Collectors.toList());

            return CompletableFutures.allOf(futures).thenAccept(responses -> {
                Int2ObjectMap<FetcherShardHealth> shardHealthById =
                    computeShardHealthByIdFromResponses(responses);

                updateMeters(shardHealthById);
            });
        } catch (Throwable t) {
            RuntimeException wrappedException =
                new RuntimeException("unexpectedly failed to update shard healths", t);
            return CompletableFuture.failedFuture(wrappedException);
        }
    }

    /**
     * Combines the per-host responses into a single health entry per shard id, using
     * {@link #mergeCrossHostHealths} when several hosts report the same shard.
     */
    private static Int2ObjectMap<FetcherShardHealth> computeShardHealthByIdFromResponses(List<ShardsHealthResponse> responses) {
        // Upper bound on distinct shards; used only to presize the map.
        int shardsCountFromAllHosts =
            responses.stream()
                .mapToInt(ShardsHealthResponse::getShardsCount)
                .sum();

        var healthByShardId = new Int2ObjectOpenHashMap<FetcherShardHealth>(shardsCountFromAllHosts);
        for (ShardsHealthResponse response : responses) {
            for (FetcherShardHealth shardHealth : response.getShardsList()) {
                FetcherShardHealth prevShardHealth = healthByShardId.get(shardHealth.getNumId());
                FetcherShardHealth mergedShardHealth = mergeCrossHostHealths(prevShardHealth, shardHealth);

                healthByShardId.put(shardHealth.getNumId(), mergedShardHealth);
            }
        }

        return healthByShardId;
    }

    /**
     * Picks one of two reports for the same shard coming from different hosts.
     *
     * <p>A report with neither ok nor failed URLs is ignored in favour of the other one; otherwise the
     * report with the lower instant failure percentage wins (ties go to {@code nextShardHealth}).
     */
    private static FetcherShardHealth mergeCrossHostHealths(
        @Nullable FetcherShardHealth prevShardHealth,
        FetcherShardHealth nextShardHealth)
    {
        if (prevShardHealth == null) {
            return nextShardHealth;
        }

        // Ignore a report that covers no URLs at all
        if (prevShardHealth.getUrlsFail() == 0 && prevShardHealth.getUrlsOk() == 0) {
            return nextShardHealth;
        }

        if (nextShardHealth.getUrlsFail() == 0 && nextShardHealth.getUrlsOk() == 0) {
            return prevShardHealth;
        }

        // Select shard health from the host with the lowest failed-targets percentage
        int prevHealthValue = computeInstantShardHealth(prevShardHealth);
        int nextHealthValue = computeInstantShardHealth(nextShardHealth);

        return prevHealthValue < nextHealthValue ? prevShardHealth : nextShardHealth;
    }

    /**
     * Returns the integer failure percentage {@code 100 * fail / (ok + fail)} of a single report.
     * Callers must ensure {@code ok + fail != 0}. Computed in {@code long} to avoid {@code int} overflow.
     */
    private static int computeInstantShardHealth(FetcherShardHealth shardHealth) {
        long failed = shardHealth.getUrlsFail();
        long total = failed + shardHealth.getUrlsOk();
        return (int) (100 * failed / total);
    }

    /**
     * Requests shard healths from one host; any failure (synchronous or asynchronous) is logged and
     * replaced by an empty response.
     */
    private CompletableFuture<ShardsHealthResponse> loadShardHealthsByHostOrEmpty(String host) {
        try {
            return fetcherClient.getFetcherShardsHealths(host).handle(((shardsHealthResponse, throwable) -> {
                if (throwable != null) {
                    logger.warn("failed to load shard healths from host {}: {}", host, throwable.getMessage());
                    return ShardsHealthResponse.getDefaultInstance();
                }

                return shardsHealthResponse;
            }));
        } catch (Throwable t) {
            logger.warn("unexpectedly failed to load shard health from host {}, {}", host, t.getMessage());
            return CompletableFuture.completedFuture(ShardsHealthResponse.getDefaultInstance());
        }
    }

    /**
     * Applies a fresh snapshot to the meters: tracked shards missing from the snapshot are dropped,
     * present ones are updated, and shards seen for the first time get new meters.
     *
     * <p>Note: entries are removed from {@code newShardHealthById} while processing.
     */
    private void updateMeters(Int2ObjectMap<FetcherShardHealth> newShardHealthById) {
        try {
            IntList shardIds = new IntArrayList(shardHealthById.keySet());
            for (int i = 0; i < shardIds.size(); i++) {
                int shardId = shardIds.getInt(i);
                FetcherShardHealth health = newShardHealthById.remove(shardId);
                if (health == null) {
                    // shard was moved or deleted
                    shardHealthById.remove(shardId);
                } else {
                    var meters = shardHealthById.computeIfAbsent(shardId, id -> new ShardHealthMeters());
                    meters.update(health.getUrlsOk(), health.getUrlsFail());
                }
            }

            // whatever is left in the snapshot belongs to shards that were not tracked yet
            newShardHealthById.forEach((shardId, health) -> {
                var meters = shardHealthById.computeIfAbsent(shardId, id -> new ShardHealthMeters());
                meters.update(health.getUrlsOk(), health.getUrlsFail());
            });
        } catch (Throwable t) {
            logger.warn("cannot update shards health", t);
        }
    }

    /**
     * One-minute exponential moving averages of ok and failed URL counts for a single shard.
     */
    private static final class ShardHealthMeters extends TickMixin {
        private final ExpMovingAverage urlsOk;
        private final ExpMovingAverage urlsFail;

        private static final long TICK_INTERVAL_NANOS = ExpMovingAverage.oneMinute().getTickIntervalNanos();

        public ShardHealthMeters() {
            super(TICK_INTERVAL_NANOS);
            // Adjust TICK_INTERVAL_NANOS if EMA below are changed
            this.urlsOk = ExpMovingAverage.oneMinute();
            this.urlsFail = ExpMovingAverage.oneMinute();
        }

        @Override
        protected void onTick() {
            urlsOk.tick();
            urlsFail.tick();
        }

        /** Ticks the averages if due, then feeds in the latest counts. */
        void update(int urlsOk, int urlsFail) {
            tickIfNecessary();
            this.urlsOk.update(urlsOk);
            this.urlsFail.update(urlsFail);
        }

        /**
         * Returns the failure percentage derived from the averaged rates, rounded to 3 decimal places.
         * Returns 0 when the fail rate is zero or either rate is negative/NaN.
         */
        double getHealth() {
            double okRate = urlsOk.getRate(TimeUnit.SECONDS);
            double failRate = urlsFail.getRate(TimeUnit.SECONDS);

            if (failRate >= 0 && okRate >= 0) {
                // 0%: all metrics are fetched without fails
                if (failRate == 0) {
                    return 0;
                }

                // 100%: no metrics are fetched successfully
                if (okRate == 0) {
                    return 100;
                }

                double failurePercent = 100 * failRate / (failRate + okRate);

                // keep only 3 digits after point
                return Math.round(failurePercent * 1000) / 1000.0;
            }
            return 0.0;
        }
    }
}
