package ru.yandex.chemodan.queller.rabbit;

import java.text.DecimalFormat;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import com.rabbitmq.client.AMQP;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.springframework.amqp.AmqpException;
import org.springframework.amqp.core.Binding;
import org.springframework.amqp.core.BindingBuilder;
import org.springframework.amqp.core.DirectExchange;
import org.springframework.amqp.core.Exchange;
import org.springframework.amqp.core.MessageListener;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.amqp.rabbit.connection.ConnectionListener;
import org.springframework.amqp.rabbit.core.RabbitAdmin;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.amqp.rabbit.support.CorrelationData;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.bolts.collection.IteratorF;
import ru.yandex.bolts.collection.ListF;
import ru.yandex.bolts.collection.MapF;
import ru.yandex.bolts.collection.Option;
import ru.yandex.chemodan.queller.celery.monitoring.CeleryMetrics;
import ru.yandex.chemodan.queller.celery.worker.WorkerStateProviderHolder;
import ru.yandex.commune.dynproperties.DynamicProperty;
import ru.yandex.misc.concurrent.CountDownLatches;
import ru.yandex.misc.ip.Host;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;
import ru.yandex.misc.monica.core.blocks.Health;
import ru.yandex.misc.monica.core.blocks.Instrument;
import ru.yandex.misc.monica.core.blocks.InstrumentedData;
import ru.yandex.misc.monica.core.blocks.RoundRobinCounter;
import ru.yandex.misc.monica.core.name.MetricName;
import ru.yandex.misc.monica.util.measure.MeasureInfo;
import ru.yandex.misc.time.TimeUtils;
import ru.yandex.misc.worker.spring.DelayingWorkerServiceBeanSupport;

/**
 * Connection to a single RabbitMQ host.
 *
 * <p>Keeps a registry of exchanges, queues and bindings that are (re)declared whenever the underlying
 * connection is (re)created. It pings itself through a private ping queue to measure round-trip latency,
 * publishes batches of messages with publisher confirms enabled, and exposes health data via
 * {@link #getHealth()}. Periodic maintenance (ping, re-declaration or queue state polling, channel cleanup)
 * runs in {@link #execute()}.
 *
 * @author yashunsky
 */
public class RabbitConnection extends DelayingWorkerServiceBeanSupport {
    public static final Duration PING_TIMEOUT_COUNTER_DURATION = Duration.standardSeconds(60);

    private static final Logger logger = LoggerFactory.getLogger(RabbitConnection.class);

    private static final DynamicProperty<Double> maxSendingErrorRate1m =
            new DynamicProperty<>("queller.rabbit.connection.max-sending-error-rate-1m", 100.0);
    private static final DynamicProperty<Double> maxSendingDuration =
            new DynamicProperty<>("queller.rabbit.connection.max-sending-duration", 10000.0);
    private static final DynamicProperty<Double> maxPingTimeoutPerMinute =
            new DynamicProperty<>("queller.rabbit.connection.max-ping-timeout-per-minute", 10.0);
    /**
     * Ping round-trip threshold in milliseconds (compared against {@link Duration#getMillis()});
     * also passed to {@code RabbitQueues.withExpiration} when creating the ping queue.
     */
    private static final DynamicProperty<Integer> pingTimeout =
            new DynamicProperty<>("queller.rabbit.connection.ping-timeout", 10000);
    private static final DynamicProperty<Double> minSignificantRps =
            new DynamicProperty<>("queller.rabbit.connection.minSignificantRps", 10.0);

    /**
     * Formatter for RPS values in log lines. {@link DecimalFormat} is not thread-safe, and this field is
     * shared by every connection and by every caller of {@link #sendMessages}, so each thread gets its own copy.
     */
    private static final ThreadLocal<DecimalFormat> rpsFormat =
            ThreadLocal.withInitial(() -> new DecimalFormat("#.###"));

    public final RabbitConnectionPojo connectionData;
    public final Host host;

    /** Private queue this connection publishes pings to and consumes them from. */
    public final Queue pingQueue;

    public final CachingConnectionFactory factory;
    public final RabbitAdmin admin;

    /** Timing of outgoing publishes (batches and pings). */
    public final Instrument sending = new Instrument();

    /** Runs every broker operation so that it can be bounded by a timeout. */
    private final ExecutorService connectionExecutorService;

    /** Round-trip durations of received pings. */
    private final Instrument pings = new Instrument();
    private final RoundRobinCounter pingTimeoutCounter;
    private volatile boolean readyToReceive = false;
    private volatile boolean isConnected = false;
    private volatile boolean forceUse = false;

    private final ListF<ConnectedListener> listeners;

    /** Start time of the latest maintenance run; read from other threads via {@link #isActive()}. */
    private volatile Instant lastAction;

    /** Time the latest ping message was received. */
    private volatile Instant lastPing;

    private final RabbitTemplate pingTemplate;

    // Components re-declared on every (re)connect; see declareRabbitComponents().
    private final MapF<String, Exchange> maintainedExchanges;
    private final MapF<String, QueueState> maintainedQueues;
    private final MapF<String, Binding> maintainedBindings;

    private final WorkerStateProviderHolder celeryMonitorHolder;
    private final CeleryMetrics celeryMetrics;

    private final Duration maintenancePeriod;

    /**
     * Creates the connection factory and admin, registers the connect/disconnect listener, declares
     * the ping queue/exchange/binding and starts the ping consumer.
     *
     * @param data                  connection settings (host, credentials, timeouts, pool size)
     * @param monitor               optional source of celery worker state used to count consumers
     * @param metrics               metrics sink for connection and queue health
     * @param maintenancePeriod     delay between {@link #execute()} runs
     * @param serviceQueuesXExpires expiration passed to {@code RabbitQueues.withExpiration} for the ping queue
     */
    public RabbitConnection(RabbitConnectionPojo data, WorkerStateProviderHolder monitor,
            CeleryMetrics metrics, Duration maintenancePeriod, Duration serviceQueuesXExpires)
    {
        this.maintenancePeriod = maintenancePeriod;

        connectionData = data;
        host = data.host;
        // A unique name per instance, so each connection only sees its own pings.
        pingQueue = RabbitQueues.withExpiration(
                "ping_" + UUID.randomUUID(), Duration.millis(pingTimeout.get()), serviceQueuesXExpires);

        celeryMonitorHolder = monitor;
        celeryMetrics = metrics;
        connectionExecutorService = Executors.newFixedThreadPool(connectionData.executorPoolSize);

        listeners = Cf.arrayList();

        pingTimeoutCounter = new RoundRobinCounter(PING_TIMEOUT_COUNTER_DURATION);

        factory = new CachingConnectionFactory(connectionData.host.toString(), connectionData.port.getPort());

        factory.setUsername(connectionData.username);
        factory.setPassword(connectionData.password);
        factory.setVirtualHost(connectionData.virtualHost);

        // Needed by sendMessages() to learn which messages the broker acknowledged.
        factory.setPublisherConfirms(true);

        try {
            factory.afterPropertiesSet();
        } catch (Exception e) {
            logger.error("Connection factory for {} configuration failed: {}", connectionData.host, e);
        }

        admin = new RabbitAdmin(factory);

        admin.afterPropertiesSet();

        lastAction = Instant.now();
        lastPing = Instant.now();

        factory.addConnectionListener(new ConnectionListener() {
            @Override
            public void onCreate(Connection connection) {
                isConnected = true;
                listeners.forEach(ConnectedListener::restart);
                // Server-side state may have been lost, so re-declare everything we maintain.
                declareRabbitComponents();
                logger.info("Connected to {}", host);
            }

            @Override
            public void onClose(Connection connection) {
                isConnected = false;
                readyToReceive = false;
                celeryMetrics.connectedToRabbits.set(Health.unhealthy("disconnected"), factory.getHost());
                logger.error("Disconnected from {}", host);
            }
        });

        maintainedExchanges = Cf.concurrentHashMap();
        maintainedBindings = Cf.concurrentHashMap();
        maintainedQueues = Cf.concurrentHashMap();

        pingTemplate = new RabbitTemplate(factory);
        pingTemplate.setConfirmCallback((correlationData, ack, cause) -> {
            // confirmCallback required if factory's publisherConfirms is set to true
        });
        pingTemplate.setExchange(pingQueue.getName());
        pingTemplate.setRoutingKey(pingQueue.getName());

        DirectExchange pingExchange = new DirectExchange(pingQueue.getName(), false, false);
        Binding pingBinding = BindingBuilder.bind(pingQueue).to(pingExchange).withQueueName();

        declareQueue(pingQueue);
        declareExchange(pingExchange);
        declareBinding(pingBinding);

        ConnectedListener pingListener = new ConnectedListener(this);

        listeners.add(pingListener);

        // Each ping carries its send instant; the receive delay is the round-trip latency.
        pingListener.setMessageListener((MessageListener) msg -> {

            Option<Instant> pingInstant = RabbitPingMessage.parseSafe(msg);

            if (!pingInstant.isPresent()) {
                logger.warn("Unable to resolve instant and host from ping message: {}", new String(msg.getBody()));
                return;
            }

            Duration duration = new Duration(pingInstant.get(), Instant.now());

            updatePing(duration);
        });

        pingListener.setQueues(pingQueue);
        pingListener.start();

        setSleepBeforeFirstRun(false);
        setDelay(this.maintenancePeriod);
    }

    /**
     * @return true if maintenance ({@link #execute()}) started within the last two maintenance periods
     */
    public boolean isActive() {
        return lastAction.isAfter(Instant.now().minus(maintenancePeriod.multipliedBy(2)));
    }

    /**
     * Builds a health snapshot from the sending and ping instruments, the connection flags and the
     * current dynamic-property thresholds.
     */
    public RabbitConnectionHealth getHealth() {
        InstrumentedData sendingData = sending.apply();
        Double pingTimeoutPerMinute = pingTimeoutCounter.apply().doubleValue();
        InstrumentedData pingData = pings.apply();

        //TODO  think about better criteria

        return new RabbitConnectionHealth(connectionData.host,
                isActive(), isConnected, readyToReceive, forceUse,
                sendingData.statisticData().getMeter().getAverage1Min(),
                sendingData.errorRate().getAverage1Min(),
                sendingData.statisticData().getQuantiles().asMap().getTs(0.9),
                pingTimeoutPerMinute,
                pingData.statisticData().getQuantiles().asMap().getTs(0.9),
                maxSendingErrorRate1m.get(), maxSendingDuration.get(),
                maxPingTimeoutPerMinute.get(), pingTimeout.get(), minSignificantRps.get());
    }

    /** @return whether {@link #getHealth()} currently considers this connection usable */
    public boolean canBeUsed() {
        return getHealth().canBeUsed;
    }

    /**
     * @return the last known state of a queue previously passed to {@link #declareQueue}
     */
    public QueueState getQueueState(String name) {
        return maintainedQueues.getOrThrow(name, "Undeclared queue ", name);
    }

    /**
     * Declares all maintained queues, exchanges and bindings on the broker, bounded by
     * {@code connectionData.declarationTimeout}. Sets {@code readyToReceive} to the outcome and updates
     * the connection health metric.
     */
    public void declareRabbitComponents() {
        readyToReceive = executeWithTimeout(
                () -> {
                    // Re-declaring a queue also refreshes its recorded message/consumer counts.
                    maintainedQueues.forEach((n, q) -> maintainedQueues.put(n, doDeclareQueue(q.queue)));
                    maintainedExchanges.values().forEach(admin::declareExchange);
                    maintainedBindings.values().forEach(admin::declareBinding);
                    celeryMetrics.connectedToRabbits.set(Health.healthy("connected"), factory.getHost());
                    return true;
                },
                false,
                connectionData.declarationTimeout
        );

        if (!readyToReceive) {
            celeryMetrics.connectedToRabbits.set(Health.unhealthy("declaration failed"), factory.getHost());
            logger.error("Rabbit component declaration failed at {}", factory.getHost());
        }
    }

    /** Registers an exchange for (re)declaration and redeclares all maintained components. */
    public void declareExchange(Exchange exchange) {
        maintainedExchanges.put(exchange.getName(), exchange);
        declareRabbitComponents();
    }

    /** Registers a queue for (re)declaration and redeclares all maintained components. */
    public void declareQueue(Queue queue) {
        maintainedQueues.put(queue.getName(), new QueueState(queue, 0, 0));
        declareRabbitComponents();
    }

    /** Registers a binding for (re)declaration and redeclares all maintained components. */
    public void declareBinding(Binding binding) {
        maintainedBindings.put(binding.toString(), binding);
        declareRabbitComponents();
    }

    /**
     * Stops maintaining a queue and deletes it on the broker.
     *
     * @return the admin's delete result, or false on failure/timeout
     */
    public boolean deleteQueue(String queueName) {
        maintainedQueues.removeTs(queueName);
        return executeWithTimeout(() -> admin.deleteQueue(queueName), false, connectionData.declarationTimeout);
    }

    /** Records a ping round-trip and counts it as a timeout when it exceeds {@link #pingTimeout}. */
    private void updatePing(Duration pingDuration) {
        pings.update(new MeasureInfo(pingDuration, true));
        if (pingDuration.getMillis() > pingTimeout.get()) {
            pingTimeoutCounter.inc();
        }

        lastPing = Instant.now();
    }

    public void setForceUse(boolean forceUse) {
        this.forceUse = forceUse;
    }

    /**
     * Runs {@code action} on the connection executor, waiting at most {@code timeout}.
     *
     * @return true if the action completed in time without throwing
     */
    public boolean runWithTimeout(Runnable action, Duration timeout) {
        return executeWithTimeout(() -> { action.run(); return true; }, false, timeout);
    }

    /**
     * Runs {@code action} on the connection executor, waiting at most {@code timeout}.
     *
     * <p>On failure, timeout or interruption the task is cancelled, {@code readyToReceive} is cleared and
     * {@code returnOnFailure} is returned. If the calling thread was interrupted, its interrupt status is
     * restored before returning.
     *
     * @return the action's result, or {@code returnOnFailure}
     */
    public <T> T executeWithTimeout(Callable<T> action, T returnOnFailure, Duration timeout) {
        Future<T> future = connectionExecutorService.submit(action);
        try {
            return future.get(timeout.getMillis(), TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Do not swallow the interrupt: callers (e.g. a stopping worker) must still observe it.
            Thread.currentThread().interrupt();
            return onActionFailed(future, returnOnFailure, timeout, e);
        } catch (ExecutionException | TimeoutException | AmqpException e) {
            return onActionFailed(future, returnOnFailure, timeout, e);
        } catch (Throwable t) {
            readyToReceive = false;
            throw t;
        }
    }

    /** Shared failure path of {@link #executeWithTimeout}: marks not-ready, logs, cancels the task. */
    private <T> T onActionFailed(Future<T> future, T returnOnFailure, Duration timeout, Exception e) {
        readyToReceive = false;
        logger.error("Action with {}, t/o: {}, failed: {}", connectionData, timeout.getMillis(), e);
        future.cancel(true);
        return returnOnFailure;
    }

    public synchronized void addListener(ConnectedListener listener) {
        listeners.add(listener);
    }

    private synchronized void restartListeners() {
        listeners.forEach(ConnectedListener::restart);
    }

    /**
     * Logs how many of {@code totalCount} messages passed a stage and the resulting messages per second.
     * The duration is clamped to at least 1 ms so an instantaneous batch does not divide by zero.
     */
    private void logRpsSafe(String actionName, int count, int totalCount, Duration duration) {
        long millis = Math.max(duration.getMillis(), 1);
        double rps = count * 1000.0 / millis;

        logger.info("Messages " + actionName + ": {} of {} in {}. RPS: {}",
                count, totalCount, TimeUtils.toSecondsString(duration), rpsFormat.get().format(rps));
    }

    /**
     * Publishes messages in order, stopping at the first send error or when
     * {@code connectionData.batchSendingTimeout} expires.
     *
     * <p>If {@code confirm} is set, waits up to {@code connectionData.batchConfirmationTimeout} for publisher
     * confirms and marks each sent message as confirmed or not confirmed. Messages that were never sent keep
     * their {@code skipped}/{@code error} result.
     *
     * @return one {@link SendResult} per input message, in input order
     */
    public ListF<SendResult> sendMessages(ListF<RoutedMessage> routedMessages, boolean confirm) {
        if (routedMessages.isEmpty()) {
            return Cf.list();
        }

        ListF<RoutedMessageWithIndex> messages = routedMessages.zipWithIndex().map(RoutedMessageWithIndex::new);
        SendResult[] result = Cf.repeat(SendResult.skipped(), routedMessages.size()).toArray(SendResult.class);

        ListF<String> correlationIds = messages.map(m -> UUID.randomUUID().toString());
        MapF<String, Boolean> confirmations = Cf.concurrentHashMap();

        // A dedicated template, so the confirm callback only sees this batch's correlation ids.
        RabbitTemplate template = new RabbitTemplate(factory);
        CountDownLatch confirmationLatch = new CountDownLatch(routedMessages.size());

        template.setConfirmCallback(
                (correlationData, ack, cause) -> {
                    if (correlationIds.containsTs(correlationData.getId())) {
                        confirmations.put(correlationData.getId(), ack);
                        confirmationLatch.countDown();
                    }
                }
        );

        IteratorF<RoutedMessageWithIndex> messagesIt = messages.iterator();

        Instant start = Instant.now();

        boolean sentOk = runWithTimeout(() -> {
            // Thread.interrupted() lets future.cancel(true) on timeout stop the loop between sends.
            while (messagesIt.hasNext() && !Thread.interrupted()) {
                RoutedMessageWithIndex msg = messagesIt.next();
                try {
                    confirmations.put(correlationIds.get(msg.index), false);
                    template.send(
                            msg.message.exchange,
                            msg.message.queueName.getOrElse(""),
                            msg.message.message,
                            new CorrelationData(correlationIds.get(msg.index)));

                    result[msg.index] = SendResult.sentNotConfirmed();

                } catch (AmqpException e) {
                    logger.error("Failed to send message to {}: {}", host, e);
                    result[msg.index] = SendResult.error(e);
                    break;
                }
            }
        }, connectionData.batchSendingTimeout);

        if (!sentOk) {
            logger.error("Failed to send messages to {} because of timeout", connectionData.host);
        }

        int sentCount = Cf.x(result).count(SendResult::isSent);

        // No confirm will ever arrive for unsent messages; release their latch slots now.
        for (int i = 0; i < routedMessages.size() - sentCount; i ++) {
            confirmationLatch.countDown();
        }

        Duration duration = new Duration(start, Instant.now());

        if (sentCount > 0) {
            sending.update(new MeasureInfo(duration.dividedBy(sentCount), true));
        }
        logRpsSafe("submitted", sentCount, messages.size(), duration);

        if (confirm && sentCount > 0) {
            CountDownLatches.await(confirmationLatch, connectionData.batchConfirmationTimeout);

            // Messages that were never published have no entry in `confirmations`; treat them as
            // unconfirmed instead of unboxing a null below.
            ListF<Boolean> confirmationResult =
                    correlationIds.map(id -> Boolean.TRUE.equals(confirmations.getTs(id)));

            duration = new Duration(start, Instant.now());
            logRpsSafe("confirmed", confirmationResult.count(a -> a), sentCount, duration);

            IteratorF<Boolean> confirmedIt = confirmationResult.iterator();

            // Sending stops at the first failure, so the sent messages are exactly the first sentCount ones.
            messages.iterator().take(sentCount).forEachRemaining(
                    msg -> result[msg.index] = confirmedIt.next()
                            ? SendResult.sentConfirmed()
                            : SendResult.sentNotConfirmed());
        }

        return Cf.x(result);
    }

    @Override
    public void stop() {
        connectionExecutorService.shutdownNow();
        super.stop();
    }

    /**
     * One maintenance round: send a ping, advance counters, restart listeners when no ping arrived within
     * {@link #pingTimeout}, re-declare components (if not ready) or poll queue states, then close
     * unexpected channels. Logs per-step timings if the round made the connection inactive.
     */
    @Override
    protected void execute() throws Exception {
        lastAction = Instant.now();

        runWithTimeout(() -> pingTemplate.send(RabbitPingMessage.create()),
                Duration.millis(maxSendingDuration.get().longValue()));
        // NOTE(review): recorded as a success even when runWithTimeout above returned false — confirm intended.
        sending.update(new MeasureInfo(new Duration(lastAction, Instant.now()), true));

        Instant afterPing = Instant.now();

        // Presumably advances the round-robin counter by one slot per second of the maintenance period;
        // TODO confirm against RoundRobinCounter.tick semantics.
        for (int tick = 0; tick < maintenancePeriod.getStandardSeconds(); tick++) {
            pingTimeoutCounter.tick(System.currentTimeMillis());
        }

        sending.tick(System.currentTimeMillis());
        pings.tick(System.currentTimeMillis());

        if (lastPing.isBefore(Instant.now().minus(pingTimeout.get()))) {
            pingTimeoutCounter.inc();
            restartListeners();
        }

        Instant afterRestarts = Instant.now();

        if (!readyToReceive) {
            declareRabbitComponents();
        } else {
            obtainQueuesStates();
        }

        Instant afterObtainingStatesOrDeclaration = Instant.now();

        closeUnexpectedChannels();

        Instant afterClosingChannels = Instant.now();

        if (!isActive()) {
            logger.error("Connection regular operations took too much time. See details:"
                    + " ping - {}, restarts - {}, states/declaration - {}, closing channels - {}",
                    TimeUtils.toSecondsString(new Duration(lastAction, afterPing)),
                    TimeUtils.toSecondsString(new Duration(afterPing, afterRestarts)),
                    TimeUtils.toSecondsString(new Duration(afterRestarts, afterObtainingStatesOrDeclaration)),
                    TimeUtils.toSecondsString(new Duration(afterObtainingStatesOrDeclaration, afterClosingChannels))
            );
        }
    }

    /** Refreshes message/consumer counts for all maintained queues with passive declares. */
    private void obtainQueuesStates() {
        ListF<String> queueNames = maintainedQueues.keys();

        logger.debug("Obtaining states of queues {} from {}", queueNames, host);

        if (!runWithTimeout(
                () -> maintainedQueues.forEach((n, q) -> maintainedQueues.put(n, doObtainQueueState(q.queue))),
                connectionData.getPropertiesTimeout))
        {
            celeryMetrics.connectedToRabbits.set(Health.worried("queues states obtaining failed"), "" + host);
            logger.warn("Queues states obtaining failed at {}", host);
        }
    }

    /** Declares {@code queue} on the broker and returns its state (also published as metrics). */
    private QueueState doDeclareQueue(Queue queue) {
        return countMetrics(admin.getRabbitTemplate().execute(channel -> {
            AMQP.Queue.DeclareOk res = channel.queueDeclare(
                    queue.getName(), queue.isDurable(), queue.isExclusive(),
                    queue.isAutoDelete(), queue.getArguments());

            return new QueueState(queue, res.getMessageCount(), getConsumersCount(res));
        }));
    }

    /** Passively declares {@code queue} to read its state (also published as metrics). */
    private QueueState doObtainQueueState(Queue queue) {
        return countMetrics(admin.getRabbitTemplate().execute(channel -> {
            AMQP.Queue.DeclareOk res = channel.queueDeclarePassive(queue.getName());

            int messageCount = res.getMessageCount();
            int consumersCount = getConsumersCount(res);

            logger.debug("Received queue state. messages={}, consumers={}", messageCount, consumersCount);
            return new QueueState(queue, messageCount, consumersCount);
        }));
    }

    /**
     * Counts consumers of the declared queue. When a celery monitor is available, counts workers that
     * are connected to this host, are Java workers or have running processes, and listen on the queue.
     * Otherwise uses the broker-reported consumer count.
     */
    private int getConsumersCount(AMQP.Queue.DeclareOk declareOk) {
        if (celeryMonitorHolder.isSet()) {
            return celeryMonitorHolder.get().getWorkersState().values().iterator()
                    .filter(ws -> ws.rabbitHosts.containsTs(connectionData.host.format()))
                    .filter(ws -> ws.isJavaWorker || ws.processesCount.exists(c -> c > 0))
                    .filter(ws -> ws.queues.containsTs(declareOk.getQueue()))
                    .count();
        } else {
            return declareOk.getConsumerCount();
        }
    }

    /** Publishes queue message/consumer counts as metrics and returns {@code state} unchanged. */
    private QueueState countMetrics(QueueState state) {
        MetricName name = new MetricName(Cf.list(state.queue.getName(), connectionData.host.format()));

        celeryMetrics.queuesStates.set(state.messageCount, name.withSuffix("messages"));
        celeryMetrics.queuesStates.set(state.consumerCount, name.withSuffix("consumers"));

        return state;
    }

    private void closeUnexpectedChannels() {
        listeners.forEach(ConnectedListener::closeUnexpectedChannels);
    }
}
