package ru.yandex.travel.orders.client;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.Channel;
import lombok.extern.slf4j.Slf4j;


@Slf4j
public class HAGrpcChannelFactory implements ChannelConsumer {

    // Creates the ChannelInfo for a channel label; handed to computeIfAbsent when a channel is discovered.
    private final Function<String, ChannelInfo> channelInfoBuilder;

    // Counter advanced by the round-robin getters; it wraps to 0 after Integer.MAX_VALUE.
    private final AtomicInteger nextChannelCounter = new AtomicInteger(0);
    // Known channels keyed by label; entries are added and removed by the ChannelConsumer callbacks.
    private final ConcurrentHashMap<String, ChannelInfo> channels = new ConcurrentHashMap<>();
    // given a channel returns a future completed with the ChannelState reported by the ping
    private final Function<Channel, CompletableFuture<ChannelState>> pingProducer;
    // Time after which a pending ping is failed with a TimeoutException.
    private final Duration pingTimeout;
    // Single daemon thread ("ServerInfoRefresh") running the periodic refresh and the ping timeout tasks.
    private ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
            new ThreadFactoryBuilder()
                    .setDaemon(true)
                    .setNameFormat("ServerInfoRefresh")
                    .build()
    );
    // Time source used to obtain "now" when evaluating channel readiness and master state.
    private Clock clock;

    /**
     * Creates the factory and starts the periodic background refresh of the channel states.
     *
     * @param channelInfoBuilder  creates the {@link ChannelInfo} for a newly discovered channel label
     * @param pingProducer        given a channel, returns a future completed with its {@link ChannelState};
     *                            must not be {@code null}
     * @param pingTimeout         maximum time to wait for a single ping
     * @param initialRefreshDelay delay before the first refresh run
     * @param refreshDelay        delay between the end of one refresh run and the start of the next
     * @throws IllegalArgumentException if {@code pingProducer} is {@code null}
     */
    private HAGrpcChannelFactory(Function<String, ChannelInfo> channelInfoBuilder,
                                 Function<Channel, CompletableFuture<ChannelState>> pingProducer,
                                 Duration pingTimeout, Duration initialRefreshDelay, Duration refreshDelay) {
        Preconditions.checkArgument(pingProducer != null, "Ping function must be provided");
        this.channelInfoBuilder = channelInfoBuilder;
        this.pingProducer = pingProducer;
        this.pingTimeout = pingTimeout;
        // Every field is assigned before the task is submitted: the task captures `this`, and submission to
        // the executor is what makes the preceding writes visible to the refresh thread.
        this.clock = Clock.systemUTC();
        scheduledExecutorService.scheduleWithFixedDelay(
                () -> {
                    try {
                        refreshServers();
                    } catch (Exception e) {
                        // catching any exception as the first uncaught will stop all scheduling
                        log.error("An error occurred refreshing server info", e);
                    }
                }, initialRefreshDelay.toNanos(), refreshDelay.toNanos(), TimeUnit.NANOSECONDS
        );
    }

    /**
     * Returns a ready channel. This method only delegates to {@link #getFastestChannel()}.
     *
     * @return the channel returned by {@link #getFastestChannel()}
     * @throws HAGrpcException if {@link #getFastestChannel()} finds no ready channel
     * @deprecated this is a plain alias; call {@link #getFastestChannel()} directly
     */
    @Deprecated
    public Channel getAnyChannel() {
        return getFastestChannel();
    }

    /**
     * Picks the ready channel with the smallest average latency.
     *
     * @return the channel of the ready {@link ChannelInfo} whose {@code getAverageLatency()} is lowest
     * @throws HAGrpcException if none of the known channels is ready
     */
    public Channel getFastestChannel() {
        ChannelInfo fastest = channels.values().stream()
                .filter(ChannelInfo::isReady)
                .min(Comparator.comparingDouble(ChannelInfo::getAverageLatency))
                .orElseThrow(() -> new HAGrpcException("No alive channels present"));
        return fastest.getChannel();
    }

    /**
     * Returns the next ready channel in round-robin order.
     * <p>
     * Ready channels are ordered by label and indexed with the shared counter, so consecutive calls walk
     * through them in turn while the set of ready channels stays the same.
     *
     * @return the channel selected for this call
     * @throws HAGrpcException if none of the known channels is ready
     */
    public Channel getRoundRobinChannel() {
        // The counter moves on every call (even a failing one) and wraps to 0, so it never goes negative.
        final int ticket = nextChannelCounter.updateAndGet(i -> i == Integer.MAX_VALUE ? 0 : i + 1);

        final List<ChannelInfo> candidates = new ArrayList<>();
        for (ChannelInfo info : channels.values()) {
            if (info.isReady()) {
                candidates.add(info);
            }
        }
        if (candidates.isEmpty()) {
            throw new HAGrpcException("No ready channels present");
        }

        // Stable ordering by label keeps the rotation independent of the map's iteration order.
        candidates.sort(Comparator.comparing(ChannelInfo::getLabel));
        return candidates.get(ticket % candidates.size()).getChannel();
    }

    /**
     * Returns a ready channel, preferring non-master channels.
     * <p>
     * While at least one ready non-master channel exists, calls rotate (round-robin, ordered by label) over
     * the non-master channels only. The master is returned solely as a fallback when no ready non-master
     * channel is available.
     *
     * @return a ready non-master channel if there is one, otherwise the single ready master channel
     * @throws HAGrpcException if there is no ready channel at all, or if the fallback finds more than one
     *                         ready master
     */
    public Channel getRoundRobinSlavePreferredChannel() {
        // TODO (mbobrov): add better synchronization to this place
        int counter = nextChannelCounter.updateAndGet(i -> i == Integer.MAX_VALUE ? 0 : i + 1);
        // A single timestamp is used for the whole call so all channels are evaluated at the same instant.
        Instant now = Instant.now(clock);
        List<ChannelInfo> readyChannels = channels.values().stream()
                .filter(x -> x.isReady(now))
                .sorted(Comparator.comparing(ChannelInfo::getLabel))
                .collect(Collectors.toList());

        List<ChannelInfo> nonMaster = readyChannels.stream()
                .filter(x -> !x.isMaster(now))
                .collect(Collectors.toUnmodifiableList());
        if (!nonMaster.isEmpty()) {
            // Index into the non-master list: indexing into readyChannels would also hand out the master.
            return nonMaster.get(counter % nonMaster.size()).getChannel();
        }

        // No ready non-master channel: fall back to the master, which must be unambiguous.
        List<ChannelInfo> masterCandidates = readyChannels.stream()
                .filter(x -> x.isMaster(now))
                .collect(Collectors.toUnmodifiableList());
        if (masterCandidates.isEmpty()) {
            throw new HAGrpcException("No active master found");
        }
        if (masterCandidates.size() > 1) {
            throw new HAGrpcException("More than one master is present. "
                    + masterCandidates.stream().map(ChannelInfo::getLabel).collect(Collectors.joining(",")));
        }
        return masterCandidates.get(0).getChannel();
    }

    /**
     * DO NOT USE THIS METHOD. Used only for testing purposes.
     * <p>
     * Shuffles the ready channels and returns the first one.
     *
     * @return a randomly chosen ready channel
     * @throws HAGrpcException if none of the known channels is ready
     */
    public Channel getRandomChannel() {
        List<ChannelInfo> shuffled = channels.values().stream()
                .filter(ChannelInfo::isReady)
                .collect(Collectors.toCollection(ArrayList::new));
        Collections.shuffle(shuffled);
        if (shuffled.isEmpty()) {
            throw new HAGrpcException("No ready channels present");
        }
        return shuffled.get(0).getChannel();
    }

    /**
     * Returns the channel of the single known master.
     *
     * @return the channel whose {@link ChannelInfo} reports itself as master
     * @throws HAGrpcException if no channel reports itself as master, or if more than one does (the message
     *                         then lists the labels of all candidates)
     */
    public Channel getMasterChannel() {
        final List<ChannelInfo> masters = new ArrayList<>();
        for (ChannelInfo info : channels.values()) {
            if (info.isMaster()) {
                masters.add(info);
            }
        }

        if (masters.isEmpty()) {
            throw new HAGrpcException("No active master found");
        }
        if (masters.size() > 1) {
            final String labels = masters.stream()
                    .map(ChannelInfo::getLabel)
                    .collect(Collectors.joining(","));
            throw new HAGrpcException("More than one master is present. " + labels);
        }
        return masters.get(0).getChannel();
    }

    /**
     * Starts a ping for every known channel.
     * <p>
     * Failures are isolated per channel: an exception thrown synchronously while starting one ping is
     * logged and does not prevent the remaining channels from being pinged in the same run.
     */
    private void refreshServers() {
        for (ChannelInfo channelInfo : channels.values()) {
            try {
                refreshServerInfo(channelInfo);
            } catch (Exception e) {
                // The ping function is supplied by the caller and is invoked synchronously, so it may throw
                // here; without this guard one bad channel would cut the whole refresh run short.
                log.error("An error occurred refreshing server info for channel with label {}",
                        channelInfo.getLabel(), e);
            }
        }
    }

    /**
     * Creates a future that never completes normally and fails with a {@link TimeoutException} once the
     * given delay has elapsed.
     *
     * @param timeout delay before the future is failed
     * @param unit    unit of {@code timeout}
     * @param <T>     value type of the returned future; no value is ever produced by this method
     * @return a future that is completed exceptionally by the scheduler thread after the delay
     */
    private <T> CompletableFuture<T> timeoutAfter(long timeout, TimeUnit unit) {
        final CompletableFuture<T> expiring = new CompletableFuture<>();
        scheduledExecutorService.schedule(
                () -> expiring.completeExceptionally(new TimeoutException()),
                timeout,
                unit
        );
        return expiring;
    }

    /**
     * Pings one channel and records the outcome.
     * <p>
     * The ping races against a timeout future. If the ping completes normally first, the reported state is
     * passed to {@code channelInfo.heartbeat}. If the ping fails or the timeout fires first, a warning is
     * logged and no heartbeat is recorded.
     *
     * @param channelInfo the channel to ping
     */
    private void refreshServerInfo(ChannelInfo channelInfo) {
        // The ping is started before the timeout is armed, in the same order as the chained call evaluates.
        final CompletableFuture<ChannelState> ping = pingProducer.apply(channelInfo.getChannel());
        final CompletableFuture<ChannelState> deadline =
                timeoutAfter(pingTimeout.toNanos(), TimeUnit.NANOSECONDS);

        final CompletableFuture<Void> outcome = ping.acceptEither(deadline, state -> {
            log.debug("Successful ping of channel with label {}. Channel state: {}",
                    channelInfo.getLabel(), state.name());
            channelInfo.heartbeat(state);
        });

        outcome.exceptionally(failure -> {
            log.warn("Couldn't ping channel with label {} due to {}", channelInfo.getLabel(),
                    failure.getMessage());
            return null;
        });
    }

    /**
     * Registers a discovered channel under its label.
     * <p>
     * A {@link ChannelInfo} is created with {@code channelInfoBuilder} only when the label is not already
     * present; an existing entry for the same label is left untouched.
     *
     * @param channelLabel label of the discovered channel
     */
    @Override
    public void onChannelDiscovered(String channelLabel) {
        channels.computeIfAbsent(channelLabel, channelInfoBuilder);
    }

    /**
     * Forgets the channel registered under the given label, if any, so the getters no longer consider it.
     *
     * @param channelLabel label of the lost channel
     */
    @Override
    public void onChannelLost(String channelLabel) {
        channels.remove(channelLabel);
    }

    /**
     * Fluent builder for {@link HAGrpcChannelFactory}.
     * <p>
     * Failure detector properties, a channel builder and a ping producer are required; a channel supplier
     * is optional.
     */
    public static class Builder {
        private FailureDetectorProperties failureDetectorProperties;
        private Function<Channel, CompletableFuture<ChannelState>> pingProducer;
        private Function<String, LabeledChannel> channelBuilder;
        private ChannelSupplier channelSupplier;

        private Builder() {
        }

        /**
         * @return a new, empty builder
         */
        public static Builder newBuilder() {
            return new Builder();
        }

        /**
         * Sets the failure detector and refresh timing configuration (required).
         *
         * @param failureDetectorProperties thresholds, ping timeout and refresh delays
         * @return this builder
         */
        public Builder withFailureDetectorProperties(FailureDetectorProperties failureDetectorProperties) {
            this.failureDetectorProperties = failureDetectorProperties;
            return this;
        }

        /**
         * Sets the function used to ping a channel (required).
         *
         * @param pingProducer given a channel, returns a future completed with its {@link ChannelState}
         * @return this builder
         */
        public Builder withPingProducer(Function<Channel, CompletableFuture<ChannelState>> pingProducer) {
            this.pingProducer = pingProducer;
            return this;
        }

        /**
         * Sets the source of channel discovery events (optional).
         *
         * @param channelSupplier supplier the built factory is subscribed to
         * @return this builder
         */
        public Builder withChannelSupplier(ChannelSupplier channelSupplier) {
            this.channelSupplier = channelSupplier;
            return this;
        }

        /**
         * Sets the function that creates a channel for a label (required).
         *
         * @param channelBuilder maps a channel label to a {@link LabeledChannel}
         * @return this builder
         */
        public Builder withChannelBuilder(Function<String, LabeledChannel> channelBuilder) {
            this.channelBuilder = channelBuilder;
            return this;
        }

        /**
         * Builds the factory and, when a channel supplier was provided, subscribes the factory to it.
         *
         * @return the configured factory
         * @throws NullPointerException     if the failure detector properties or the channel builder
         *                                  were not provided
         * @throws IllegalArgumentException if the ping producer was not provided
         */
        public HAGrpcChannelFactory build() {
            // Snapshot the configuration: the channel-info function below runs later, every time a channel
            // is discovered, and must not pick up changes made to this builder after build().
            final FailureDetectorProperties properties = Preconditions.checkNotNull(
                    failureDetectorProperties, "Failure detector properties must be provided");
            final Function<String, LabeledChannel> labeledChannelBuilder = Preconditions.checkNotNull(
                    channelBuilder, "Channel builder must be provided");

            final Function<String, ChannelInfo> channelInfoBuilder =
                    labeledChannelBuilder.andThen(labeledChannel -> {
                        log.info("Creating channel {}", labeledChannel.getLabel());
                        return new ChannelInfo(
                                labeledChannel.getLabel(), labeledChannel.getChannel(),
                                new StatefulPhiAccrualFailureDetector(
                                        properties.getThreshold(),
                                        properties.getMaxSampleSize(),
                                        properties.getMinStdDeviation().toMillis(),
                                        properties.getAcceptableHeartbeatPause().toMillis(),
                                        properties.getFirstHeartbeatEstimate().toMillis()
                                ));
                    });
            final HAGrpcChannelFactory factory = new HAGrpcChannelFactory(
                    channelInfoBuilder,
                    pingProducer,
                    properties.getPingTimeout(),
                    properties.getInitialRefreshDelay(),
                    properties.getRefreshDelay()
            );
            if (channelSupplier != null) {
                channelSupplier.subscribe(factory);
            }
            return factory;
        }
    }
}
