package ru.yandex.travel.commons.yt;

import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Tuple2;
import ru.yandex.travel.commons.messaging.AbstractMessageBusAndKeyValueStorage;
import ru.yandex.travel.commons.messaging.ClusterHealthStatus;
import ru.yandex.travel.commons.messaging.CompressionSettings;
import ru.yandex.travel.commons.messaging.CompressionUtils;
import ru.yandex.travel.commons.messaging.Envelope;
import ru.yandex.travel.commons.proto.EMessageCodec;
import ru.yandex.travel.commons.retry.Retry;
import ru.yandex.travel.commons.retry.RetryException;
import ru.yandex.travel.commons.retry.RetryStrategyBuilder;
import ru.yandex.yt.rpcproxy.ETransactionType;
import ru.yandex.yt.ytclient.proxy.ApiServiceTransactionOptions;
import ru.yandex.yt.ytclient.proxy.LookupRowsRequest;
import ru.yandex.yt.ytclient.proxy.ModifyRowsRequest;
import ru.yandex.yt.ytclient.proxy.YtClient;
import ru.yandex.yt.ytclient.rpc.RpcError;
import ru.yandex.yt.ytclient.tables.ColumnValueType;
import ru.yandex.yt.ytclient.tables.TableSchema;
import ru.yandex.yt.ytclient.wire.UnversionedRow;


/**
 * Message-bus / key-value storage adapter backed by a single YT table on one cluster.
 *
 * <p>Envelopes passed to {@link #send(Envelope)} are buffered in an in-memory queue and written by a background
 * writer task in batches of at most {@code maxBatchSize} rows, each batch in its own tablet transaction, with at
 * most {@code maxConcurrentWrites} batches in flight. {@link #get(String, Class)} is only supported when the table
 * is sorted (keyed by {@code MessageId}). An optional periodic health check (see {@link #startHealthCheckThread()})
 * tracks whether the cluster has alive destinations and, if configured, whether a ping message can be written
 * and read back.
 */
public class SingleClusterYtAdapter extends AbstractMessageBusAndKeyValueStorage {
    /** YT error code treated as a row conflict: tolerated for shared ping ids and retried for batch writes. */
    public static final int ROW_CONFLICT = 1700;
    private static final Logger logger = LoggerFactory.getLogger(SingleClusterYtAdapter.class);
    private static final String TIMESTAMP = "Timestamp";
    private static final String MESSAGE_TYPE = "MessageType";
    private static final String CODEC = "Codec";
    private static final String BYTES = "Bytes";
    private static final String MESSAGE_ID = "MessageId";
    private static final String EXPIRE_TIMESTAMP = "ExpireTimestamp";
    private final String clusterName;
    private final String metricPrefix;
    private final YtClusterPropertiesInterface clusterConfig;
    private final ConnectionFactory connectionFactory;
    private final CompressionSettings compressionSettings;
    private final ConcurrentLinkedQueue<EnvelopeWrapper> writeQueue;
    // Number of real (non-marker) envelopes in writeQueue; ConcurrentLinkedQueue.size() is O(n).
    private final AtomicInteger queueSize = new AtomicInteger(0);
    private final Retry retryHelper;
    private final boolean useUniquePingId;
    private final String pingMessageId;
    private final AtomicBoolean hasAliveDestinations = new AtomicBoolean(false);
    private final AtomicBoolean writerIsAlive = new AtomicBoolean(false);
    private final AtomicBoolean readerIsAlive = new AtomicBoolean(false);
    private final AtomicBoolean healthCheckThreadIsRunning = new AtomicBoolean(false);
    private final AtomicInteger activeWrites = new AtomicInteger(0);
    private final AtomicLong lastWriteTestTimeInMillis = new AtomicLong(-1);
    private final AtomicLong lastReadTestTimeInMillis = new AtomicLong(-1);
    private final ScheduledExecutorService healthCheckExecutorService;
    private final ScheduledExecutorService sendingExecutorService;
    private final AtomicBoolean writerIsInitialized = new AtomicBoolean(false);
    private final AtomicBoolean stopCalled = new AtomicBoolean(false);
    // Metrics.
    private final Timer sendTotalTimer;
    private final Timer sendActivityTimer;
    private final Timer sendStallTimer;
    private final Counter byteSentCounter;
    private final Counter envelopeSentCounter;
    private final Counter sendSuccessCounter;
    private final Counter sendFailCounter;
    private final Counter sendDropCounter;
    private final Timer readTimer;
    private final Counter byteReadCounter;
    private final Counter readSuccessCounter;
    private final Counter readFailCounter;
    private final Counter readEmptyCounter;
    private final Counter writerExceptions;
    private final DistributionSummary writerBatchSize;
    private final boolean sorted;
    private final boolean writeHealthChecks;
    private final boolean readHealthChecks;
    private final Supplier<Message> pingMessageSupplier;
    private final Duration pingMessageLifetime;
    // Assigned by the first send() caller and read by close() on another thread, hence volatile.
    private volatile ScheduledFuture<?> writerThread = null;
    private TableSchema tableSchema;


    /**
     * Creates the adapter, its executors and its metrics. No background work starts here: the writer task is
     * scheduled lazily on the first {@link #send(Envelope)}, and health checks start with
     * {@link #startHealthCheckThread()}.
     *
     * @param clusterName          YT cluster name, used for client lookup, metric tags and thread names
     * @param metricPrefix         prefix inserted into every metric name ({@code yt.<prefix>.cluster.*})
     * @param clusterConfig        table path, batching, queue, concurrency and health-check settings
     * @param compressionSettings  returned by {@link #getCompressionSettings()}
     * @param retryHelper          used to retry batch writes that fail with {@link #ROW_CONFLICT}
     * @param connectionFactory    source of the {@link YtClient} for {@code clusterName}
     * @param sorted               whether the table is keyed by {@code MessageId} (required for {@code get})
     * @param writeHealthChecks    whether health checks write a ping message
     * @param readHealthChecks     whether health checks read the ping message back (sorted tables only)
     * @param pingMessageSupplier  supplies ping messages; {@code null} disables pings
     * @param pingMessageLifetime  lifetime passed along with each ping message
     * @param useUniquePingId      use a random ping id instead of the shared {@code "0-0-0-0-0"}
     */
    public SingleClusterYtAdapter(String clusterName, String metricPrefix, YtClusterPropertiesInterface clusterConfig,
                                  CompressionSettings compressionSettings, Retry retryHelper,
                                  ConnectionFactory connectionFactory, boolean sorted, boolean writeHealthChecks,
                                  boolean readHealthChecks, Supplier<Message> pingMessageSupplier,
                                  Duration pingMessageLifetime, boolean useUniquePingId) {
        this.clusterName = clusterName;
        this.metricPrefix = metricPrefix;
        this.clusterConfig = clusterConfig;
        this.connectionFactory = connectionFactory;
        this.writeHealthChecks = writeHealthChecks;
        this.readHealthChecks = readHealthChecks;
        this.writeQueue = new ConcurrentLinkedQueue<>();
        this.retryHelper = retryHelper;
        this.useUniquePingId = useUniquePingId;
        this.pingMessageId = useUniquePingId ? UUID.randomUUID().toString() : "0-0-0-0-0";
        this.compressionSettings = compressionSettings;
        this.sorted = sorted;
        this.pingMessageSupplier = pingMessageSupplier;
        this.pingMessageLifetime = pingMessageLifetime;

        healthCheckExecutorService = Executors.newScheduledThreadPool(1,
                new ThreadFactoryBuilder().setNameFormat(clusterName + "-health-check-thread-%d").build());
        sendingExecutorService = Executors.newScheduledThreadPool(this.clusterConfig.getMaxConcurrentWrites(),
                new ThreadFactoryBuilder().setNameFormat(clusterName + "-send-to-yt-thread-%d").build());

        // Initialize metrics.
        sendStallTimer = YtAdapterMetricsHelper.createTimer(
                String.format("yt.%s.cluster.sendStallTime", metricPrefix), "clusterName", clusterName);
        sendActivityTimer = YtAdapterMetricsHelper.createTimer(
                String.format("yt.%s.cluster.sendActiveTime", metricPrefix), "clusterName", clusterName);
        sendTotalTimer = YtAdapterMetricsHelper.createTimerWithExtendedBuckets(
                String.format("yt.%s.cluster.sendTotalTime", metricPrefix), "clusterName", clusterName);

        byteSentCounter = Metrics.counter(String.format("yt.%s.cluster.bytesSent", metricPrefix), "clusterName",
                clusterName);
        envelopeSentCounter = Metrics.counter(String.format("yt.%s.cluster.envelopesSent", metricPrefix),
                "clusterName", clusterName);
        sendSuccessCounter = Metrics.counter(String.format("yt.%s.cluster.sendsSucceeded", metricPrefix),
                "clusterName", clusterName);
        sendFailCounter = Metrics.counter(String.format("yt.%s.cluster.sendsFailed", metricPrefix), "clusterName",
                clusterName);
        sendDropCounter = Metrics.counter(String.format("yt.%s.cluster.sendsDropped", metricPrefix), "clusterName",
                clusterName);

        readTimer = YtAdapterMetricsHelper.createTimer(
                String.format("yt.%s.cluster.readTime", metricPrefix), "clusterName", clusterName);

        byteReadCounter = Metrics.counter(String.format("yt.%s.cluster.bytesRead", metricPrefix), "clusterName",
                clusterName);
        readSuccessCounter = Metrics.counter(String.format("yt.%s.cluster.readSucceeded", metricPrefix), "clusterName"
                , clusterName);
        readFailCounter = Metrics.counter(String.format("yt.%s.cluster.readFailed", metricPrefix), "clusterName",
                clusterName);
        readEmptyCounter = Metrics.counter(String.format("yt.%s.cluster.readEmpty", metricPrefix), "clusterName",
                clusterName);
        writerExceptions = Metrics.counter(String.format("yt.%s.cluster.writerExceptions", metricPrefix), "clusterName",
                clusterName);

        writerBatchSize = DistributionSummary.builder(String.format("yt.%s.cluster.writerBatchSize", metricPrefix))
                .tag("clusterName", clusterName)
                .serviceLevelObjectives(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)
                .register(Metrics.globalRegistry);

        Gauge.builder(String.format("yt.%s.cluster.activeWrites", metricPrefix), activeWrites, AtomicInteger::get)
                .tag("clusterName", clusterName)
                .register(Metrics.globalRegistry);
        Gauge.builder(String.format("yt.%s.cluster.queueSize", metricPrefix), queueSize, AtomicInteger::get)
                .tag("clusterName", clusterName)
                .register(Metrics.globalRegistry);
        buildSchema();
    }


    /**
     * Reports cluster health from the flags maintained by the health-check thread.
     *
     * <p>The cluster is down if it has no alive destinations, or if an enabled ping check has not succeeded on its
     * latest attempt. Read pings are only performed on sorted tables, so reader liveness is only required there.
     *
     * @return status with a single detail entry keyed by the cluster name
     */
    @Override
    public ClusterHealthStatus isAlive() {
        ClusterHealthStatus health = new ClusterHealthStatus();
        boolean writerRequired = writeHealthChecks;
        boolean readerRequired = readHealthChecks && sorted;
        if (!hasAliveDestinations.get()) {
            health.setUp(false);
            health.setDetails(ImmutableMap.of(clusterName, "No alive destinations"));
        } else if ((writerRequired && !writerIsAlive.get()) || (readerRequired && !readerIsAlive.get())) {
            health.setUp(false);
            health.setDetails(ImmutableMap.of(clusterName, "No ping"));
        } else {
            health.setUp(true);
            health.setDetails(ImmutableMap.of(clusterName, String.format("Ping status: write %d ms, read %d ms",
                    lastWriteTestTimeInMillis.get(), lastReadTestTimeInMillis.get())));
        }
        return health;
    }

    /** @return {@code true} when {@link #isAlive()} reports the cluster as up */
    @Override
    public boolean isHealthy() {
        return isAlive().isUp();
    }

    /**
     * Builds the table schema used for lookups and inserts. Sorted tables use {@code MessageId} as a unique key;
     * unsorted tables store it as a plain value column.
     */
    private void buildSchema() {
        TableSchema.Builder builder = new TableSchema.Builder();
        if (sorted) {
            builder
                    .setUniqueKeys(true)
                    .addKey(MESSAGE_ID, ColumnValueType.STRING);
        } else {
            builder
                    .setUniqueKeys(false)
                    .addValue(MESSAGE_ID, ColumnValueType.STRING);
        }
        builder
                .addValue(TIMESTAMP, ColumnValueType.UINT64)
                .addValue(MESSAGE_TYPE, ColumnValueType.STRING)
                .addValue(CODEC, ColumnValueType.UINT64)
                .addValue(BYTES, ColumnValueType.STRING)
                .addValue(EXPIRE_TIMESTAMP, ColumnValueType.UINT64);

        tableSchema = builder.build();
    }

    /**
     * Scheduled entry point of the health check. A periodic task that throws is never run again by
     * {@link ScheduledExecutorService}, so unexpected failures are logged and the cluster is marked as having
     * no alive destinations instead of silently freezing the last reported state.
     */
    private void runHealthCheck() {
        try {
            checkHealth();
        } catch (RuntimeException e) {
            logger.error(String.format("Health check of cluster '%s' failed unexpectedly", clusterName), e);
            hasAliveDestinations.set(false);
        }
    }

    /**
     * Runs one health-check round: updates {@code hasAliveDestinations}, then (when a ping supplier is configured)
     * performs the enabled write and read ping checks.
     */
    private void checkHealth() {
        boolean noAliveDestinations = getClient().getAliveDestinations()
                .getOrDefault(clusterName, Collections.emptyList())
                .isEmpty();
        if (noAliveDestinations) {
            logger.error("Cluster '{}' has no alive destinations", clusterName);
            hasAliveDestinations.set(false);
            return;
        }
        hasAliveDestinations.set(true);
        if (pingMessageSupplier == null) {
            logger.warn("No ping supplier specified, will neither send nor receive any pings");
            // Pings are disabled entirely, so neither direction may keep the cluster reported as down.
            writerIsAlive.set(true);
            readerIsAlive.set(true);
            return;
        }
        if (writeHealthChecks) {
            checkWritePing();
        }
        if (readHealthChecks && sorted) {
            checkReadPing();
        }
    }

    /**
     * Writes a ping message (via {@code put} on sorted tables, {@code send} otherwise) and waits up to the
     * health-check timeout. A row conflict on the shared ping id is treated as success.
     */
    private void checkWritePing() {
        logger.debug("Sending ping to cluster '{}'", clusterName);
        long timestamp = System.currentTimeMillis();
        try {
            Message ping = pingMessageSupplier.get();
            CompletableFuture<Void> pingFuture = sorted ? put(pingMessageId, ping, pingMessageLifetime) :
                    send(ping, pingMessageLifetime);
            pingFuture.get(clusterConfig.getHealthCheckTimeout().toNanos(), TimeUnit.NANOSECONDS);
            markWritable(timestamp);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            Throwable cause = e.getCause();
            if (cause instanceof RetryException) {
                cause = cause.getCause();
            }
            if (cause instanceof RpcError && !useUniquePingId && ((RpcError) cause).findMatchingError(ROW_CONFLICT) != null) {
                // Several writers may share the fixed ping id; a conflict still proves the cluster accepted a write.
                logger.warn("RowConflict occurred while writing ping, assuming successful ping");
                markWritable(timestamp);
            } else {
                // Timeouts and synchronous failures carry no cause; log the exception itself then.
                Throwable failure = e.getCause() != null ? e.getCause() : e;
                logger.error(String.format("Ping to cluster '%s' failed with exception", clusterName), failure);
                if (writerIsAlive.getAndSet(false)) {
                    logger.error(String.format("Cluster '%s' is not writable", clusterName), failure);
                }
            }
        }
    }

    /** Records the write-ping latency measured from {@code startMillis} and marks the cluster writable. */
    private void markWritable(long startMillis) {
        lastWriteTestTimeInMillis.set(System.currentTimeMillis() - startMillis);
        if (!writerIsAlive.getAndSet(true)) {
            logger.info("Cluster '{}' is writable", clusterName);
        }
    }

    /** Reads the ping row back via {@link #get(String, Class)} within the health-check timeout. */
    private void checkReadPing() {
        logger.debug("Polling for ping table path '{}'", clusterConfig.getTablePath());
        try {
            long timestamp = System.currentTimeMillis();
            get(pingMessageId, pingMessageSupplier.get().getClass())
                    .get(clusterConfig.getHealthCheckTimeout().toNanos(), TimeUnit.NANOSECONDS);
            lastReadTestTimeInMillis.set(System.currentTimeMillis() - timestamp);
            if (!readerIsAlive.getAndSet(true)) {
                logger.info("Cluster '{}' is readable", clusterName);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.error(String.format("Ping-polling of cluster '%s' failed with exception", clusterName), e);
            if (readerIsAlive.getAndSet(false)) {
                logger.error(String.format("Cluster '%s' is not readable", clusterName), e);
            }
        }
    }

    /**
     * Enqueues an envelope for asynchronous batch writing. The first call schedules the writer task.
     *
     * @return future completed when the envelope's batch is committed, or completed exceptionally if the write
     * fails, the queue is full, or the adapter stops before the envelope is written
     * @throws RuntimeException if {@link #close()} has already been called
     */
    @Override
    public CompletableFuture<Void> send(Envelope envelope) {
        if (stopCalled.get()) {
            throw new RuntimeException(String.format("Tried to send to stopping cluster '%s'", clusterName));
        }
        if (!writerIsInitialized.getAndSet(true)) {
            writerThread = sendingExecutorService.schedule(
                    this::startSendingToYt,
                    clusterConfig.getSendingInterval().toNanos(),
                    TimeUnit.NANOSECONDS
            );
        }

        envelopeSentCounter.increment();
        byteSentCounter.increment(envelope.getBytes().length);
        if (queueSize.get() >= clusterConfig.getMaxQueueSize()) {
            logger.error("Dropping message {} to cluster '{}' because queue is full", envelope, clusterName);
            sendDropCounter.increment();
            CompletableFuture<Void> result = new CompletableFuture<>();
            result.completeExceptionally(new RuntimeException("Sending queue is full"));
            return result;
        }

        EnvelopeWrapper wrapper = new EnvelopeWrapper(envelope);
        // Count before publishing so the writer (which decrements after polling) never drives the size negative.
        queueSize.incrementAndGet();
        writeQueue.add(wrapper);
        return wrapper.getFuture();
    }


    /**
     * Resolves, via reflection, the protobuf full name and the static {@code parseFrom(byte[])} method of a
     * generated message class.
     *
     * @throws RuntimeException if the class lacks {@code getDescriptor()} or {@code parseFrom(byte[])}
     */
    public Tuple2<String, Method> getProtoNameAndParser(Class<? extends Message> messageClass) {
        try {
            Descriptors.Descriptor descriptor =
                    (Descriptors.Descriptor) messageClass.getMethod("getDescriptor").invoke(null);
            String protoName = descriptor.getFullName();
            Method parserMethod = messageClass.getMethod("parseFrom", byte[].class);
            return Tuple2.tuple(protoName, parserMethod);
        } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException ex) {
            throw new RuntimeException("Unable to register message type " + messageClass.getSimpleName(), ex);
        }
    }

    /**
     * Looks up the row with the given key and decodes it as {@code messageClass}.
     *
     * @return future of the decoded message, or of {@code null} when no row exists or {@code messageClass} is null;
     * fails with {@link UnsupportedMessageTypeException} if the stored type differs, and fails immediately when the
     * cluster currently has no alive destinations
     * @throws UnsupportedOperationException if the table is not sorted
     */
    @Override
    public <T extends Message> CompletableFuture<T> get(String key, Class<? extends T> messageClass) {
        if (!sorted) {
            throw new UnsupportedOperationException("'get' is supported only on sorted tables");
        }
        if (messageClass == null) {
            return CompletableFuture.completedFuture(null);
        }
        if (!hasAliveDestinations.get()) {
            // For some reason client.lookupRows below turns out to be blocking (despite returning a future) when
            // the client currently has no alive destinations. So it is important to return a failed future before
            // calling lookupRows; otherwise the whole get method would effectively become synchronous on client
            // discovery.
            return CompletableFuture.failedFuture(new RuntimeException("No Alive Destinations for " + this.clusterName));
        }
        YtClient client = getClient();
        LookupRowsRequest request = new LookupRowsRequest(clusterConfig.getTablePath(), tableSchema.toLookup())
                .addFilter(key)
                .addLookupColumns(MESSAGE_TYPE, CODEC, BYTES);
        long readStartTime = System.currentTimeMillis();
        var protoAndParser = getProtoNameAndParser(messageClass);
        String proto = protoAndParser.get1();
        Method parser = protoAndParser.get2();
        return client.lookupRows(request).thenApply(urs -> {
            long readEndTime = System.currentTimeMillis();
            readTimer.record(readEndTime - readStartTime, TimeUnit.MILLISECONDS);
            if (urs.getRows().size() == 0 || (urs.getRows().size() == 1 && urs.getRows().get(0) == null)) {
                readEmptyCounter.increment();
                return null;
            }
            for (UnversionedRow row : urs.getRows()) {
                // Column order follows addLookupColumns above: MessageType, Codec, Bytes.
                String messageType = row.getValues().get(0).stringValue();
                if (!messageType.equals(proto)) {
                    throw new UnsupportedMessageTypeException(messageType, proto);
                }
                int codec = (int) row.getValues().get(1).longValue();
                byte[] bytes = row.getValues().get(2).bytesValue();
                byteReadCounter.increment(bytes.length);
                byte[] decompressed = CompressionUtils.decompress(EMessageCodec.forNumber(codec), bytes);
                try {
                    Message res = (Message) parser.invoke(null, (Object) decompressed);
                    readSuccessCounter.increment();
                    return messageClass.cast(res);
                } catch (IllegalAccessException | InvocationTargetException e) {
                    readFailCounter.increment();
                    throw new RuntimeException("Unable to deserialize message of type " + messageType, e);
                }
            }
            readEmptyCounter.increment();
            return null;
        });
    }

    @Override
    public CompressionSettings getCompressionSettings() {
        return compressionSettings;
    }

    /**
     * Writer loop, run once on {@code sendingExecutorService}. Each iteration waits up to the sending interval
     * (unless the queue already holds a full batch or stop was requested), drains up to {@code maxBatchSize}
     * envelopes and hands them to {@link #writeBatch}. The loop ends when the finish marker enqueued by
     * {@link #close()} is polled; envelopes that arrived after the marker are then failed.
     */
    private void startSendingToYt() {
        var semaphore = new Semaphore(clusterConfig.getMaxConcurrentWrites());
        var lastIterationStartTime = System.currentTimeMillis();
        var client = getClient();
        var finishMarkerReached = false;
        while (!finishMarkerReached) {
            try {
                var timeSinceLastWriteStartMs = System.currentTimeMillis() - lastIterationStartTime;
                var timeToSleepMs = clusterConfig.getSendingInterval().toMillis() - timeSinceLastWriteStartMs;
                if (!stopCalled.get() && queueSize.get() < clusterConfig.getMaxBatchSize() && timeToSleepMs > 0) {
                    Thread.sleep(timeToSleepMs);
                }
                lastIterationStartTime = System.currentTimeMillis();

                List<EnvelopeWrapper> batch = new ArrayList<>();
                while (batch.size() < clusterConfig.getMaxBatchSize()) {
                    var wrapper = writeQueue.poll();
                    if (wrapper == null) {
                        break;
                    }
                    if (wrapper.getEnvelope() == null) {
                        finishMarkerReached = true;
                        break;
                    }
                    queueSize.decrementAndGet();
                    batch.add(wrapper);
                }

                writerBatchSize.record(batch.size());

                if (!batch.isEmpty()) {
                    writeBatch(client, semaphore, batch);
                }
            } catch (Exception e) {
                // The loop only stops on the finish marker; everything else (including interrupts) is logged.
                logger.error(String.format("Exception yt-writer thread for cluster '%s'", clusterName), e);
                writerExceptions.increment();
            }
        }
        failRemainingEnvelopes();
    }

    /**
     * Starts an asynchronous, retried write of one batch, limited by {@code semaphore}. Every future in the batch
     * is completed on every path, and the write slot is always released.
     *
     * @throws InterruptedException if interrupted while waiting for a write slot (the batch is failed first)
     */
    private void writeBatch(YtClient client, Semaphore semaphore, List<EnvelopeWrapper> batch)
            throws InterruptedException {
        try {
            semaphore.acquire();
        } catch (InterruptedException e) {
            failBatch(batch, e);
            throw e;
        }
        activeWrites.incrementAndGet();

        var writeStartTime = System.currentTimeMillis();
        try {
            retryHelper.withRetry("YtSent", () -> sendImpl(client, batch, writeStartTime),
                    new RetryStrategyBuilder<Void>()
                            .retryOnException(
                                    ex -> ex instanceof RpcError && ((RpcError) ex).findMatchingError(ROW_CONFLICT) != null)
                            .setNumRetries(3)
                            .build())
                    .whenComplete((v, t) -> onBatchWritten(batch, writeStartTime, t))
                    .whenComplete((v, t) -> releaseWriteSlot(semaphore));
        } catch (RuntimeException e) {
            // withRetry failed before handing back a future, so no callback will ever run: clean up here.
            releaseWriteSlot(semaphore);
            failBatch(batch, e);
            throw e;
        }
    }

    /** Completes the batch's futures with the write outcome {@code t} and records send metrics. */
    private void onBatchWritten(List<EnvelopeWrapper> batch, long writeStartTime, Throwable t) {
        long writeEndTime = System.currentTimeMillis();
        if (t != null) {
            if (connectionFactory.getIsClosed() && t.getCause() instanceof ClosedChannelException) {
                logger.warn("{}: Could not send {} messages to cluster '{}' because of on-going " +
                                "shutdown",
                        metricPrefix, batch.size(), clusterName);
            } else {
                logger.error(String.format("%s: Could not send %s messages to cluster '%s'",
                        metricPrefix, batch.size(),
                        clusterName), t);
            }
            failBatch(batch, t);
        } else {
            logger.debug("{}: Successfully sent {} messages to cluster '{}'", metricPrefix, batch.size(),
                    clusterName);
            batch.forEach(wrapper -> wrapper.getFuture().complete(null));
            sendSuccessCounter.increment(batch.size());
        }
        for (var wrapper : batch) {
            sendActivityTimer.record(writeEndTime - writeStartTime, TimeUnit.MILLISECONDS);
            sendTotalTimer.record(writeEndTime - wrapper.getEnvelope().getTimestamp(),
                    TimeUnit.MILLISECONDS);
        }
    }

    /** Completes every future in {@code batch} exceptionally with {@code t} and counts them as failed sends. */
    private void failBatch(List<EnvelopeWrapper> batch, Throwable t) {
        batch.forEach(wrapper -> wrapper.getFuture().completeExceptionally(t));
        sendFailCounter.increment(batch.size());
    }

    /** Returns one concurrent-write slot taken in {@link #writeBatch}. */
    private void releaseWriteSlot(Semaphore semaphore) {
        semaphore.release();
        activeWrites.decrementAndGet();
    }

    /**
     * Fails envelopes still queued after the writer loop ended (sends that raced with {@link #close()}), so their
     * futures do not stay pending forever.
     */
    private void failRemainingEnvelopes() {
        EnvelopeWrapper wrapper;
        while ((wrapper = writeQueue.poll()) != null) {
            if (wrapper.getEnvelope() == null) {
                continue;
            }
            queueSize.decrementAndGet();
            sendDropCounter.increment();
            wrapper.getFuture().completeExceptionally(new RuntimeException(
                    String.format("Cluster '%s' was stopped before the message was sent", clusterName)));
        }
    }

    /**
     * Inserts the batch in a single sticky tablet transaction and commits it, recording the stall time (enqueue
     * to write start) of each envelope.
     */
    private CompletableFuture<Void> sendImpl(YtClient client, List<EnvelopeWrapper> batch, long writeStartTime) {
        return client.startTransaction(new ApiServiceTransactionOptions(ETransactionType.TT_TABLET).setSticky(true))
                .thenCompose(t -> {
                    var mrr = new ModifyRowsRequest(clusterConfig.getTablePath(), tableSchema);
                    for (var wrapper : batch) {
                        sendStallTimer.record(writeStartTime - wrapper.getEnvelope().getTimestamp(),
                                TimeUnit.MILLISECONDS);
                        Map<String, Serializable> preparedEnvelope = prepare(wrapper.getEnvelope());
                        mrr.addInsert(preparedEnvelope);
                    }
                    return t.modifyRows(mrr).thenCompose(ignored -> t.commit());
                });
    }

    /**
     * Starts the periodic health check (idempotent). The first run is delayed by a random fraction of the
     * interval to spread checks across adapters.
     */
    public void startHealthCheckThread() {
        if (!healthCheckThreadIsRunning.getAndSet(true)) {
            logger.info("Starting health check thread for cluster '{}'", clusterName);
            healthCheckExecutorService.scheduleAtFixedRate(
                    this::runHealthCheck,
                    ThreadLocalRandom.current().nextLong(clusterConfig.getHealthCheckInterval().toNanos()),
                    clusterConfig.getHealthCheckInterval().toNanos(),
                    TimeUnit.NANOSECONDS);
        }
    }

    /** Shuts down the health-check executor, waiting at most the health-check timeout. */
    @SuppressWarnings("UnstableApiUsage")
    public void stopHealthCheckThread() {
        logger.info("Stopping health check thread for cluster '{}'", clusterName);
        MoreExecutors.shutdownAndAwaitTermination(
                healthCheckExecutorService,
                clusterConfig.getHealthCheckTimeout().toNanos(),
                TimeUnit.NANOSECONDS);
    }

    /**
     * Stops the adapter: stops health checks, rejects further sends, lets the writer drain everything queued
     * before the finish marker, then shuts down the sending executor. Batches already handed to YT may still be
     * in flight when this returns.
     */
    @Override
    public void close() {
        stopHealthCheckThread();
        // Reject new sends before enqueueing the marker, so as few envelopes as possible land behind it.
        stopCalled.set(true);
        writeQueue.add(new EnvelopeWrapper(null)); //special marker
        ScheduledFuture<?> writer = writerThread;
        if (writer != null) { // null when send() was never called
            try {
                writer.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.error(String.format("Failed to wait for cluster stop on '%s'", clusterName), e);
            } catch (ExecutionException e) {
                logger.error(String.format("Failed to wait for cluster stop on '%s'", clusterName), e);
            }
        }
        // Without this the idle writer-pool thread would outlive the adapter.
        sendingExecutorService.shutdown();
    }

    private YtClient getClient() {
        return connectionFactory.getClientForCluster(clusterName);
    }

    /** Converts an envelope into a row map; {@code ExpireTimestamp} is omitted when the envelope has none. */
    private Map<String, Serializable> prepare(Envelope envelope) {
        var builder = ImmutableMap.<String, Serializable>builder();
        builder.put(TIMESTAMP, envelope.getTimestamp())
                .put(MESSAGE_TYPE, envelope.getMessageType())
                .put(CODEC, envelope.getCodec().getNumber())
                .put(BYTES, envelope.getBytes())
                .put(MESSAGE_ID, envelope.getMessageId());
        if (envelope.getExpireTimestamp() != null) {
            builder.put(EXPIRE_TIMESTAMP, envelope.getExpireTimestamp());
        }
        return builder.build();
    }
}
