package ru.yandex.direct.binlogbroker.logbroker_utils.reader;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.direct.binlog.model.BinlogEventProtobuf;
import ru.yandex.direct.binlogbroker.logbroker_utils.models.BinlogEventWithOffset;
import ru.yandex.direct.binlogbroker.logbroker_utils.reader.impl.BatchingThresholdPredicate;
import ru.yandex.direct.binlogbroker.logbroker_utils.reader.impl.LogbrokerBatchReaderImpl;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceProfile;
import ru.yandex.direct.utils.InterruptedRuntimeException;
import ru.yandex.kikimr.persqueue.consumer.SyncConsumer;
import ru.yandex.kikimr.persqueue.consumer.transport.message.inbound.data.MessageBatch;
import ru.yandex.kikimr.persqueue.consumer.transport.message.inbound.data.MessageData;
import ru.yandex.monlib.metrics.registry.MetricRegistry;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Collections.emptySet;
import static ru.yandex.direct.binlog.model.BinlogEvent.fromProtobuf;
import static ru.yandex.direct.binlog.model.BinlogEventProtobuf.Event.parseFrom;

/**
 * Logbroker batch reader that turns raw Logbroker messages into {@link BinlogEventWithOffset}s.
 * <p>
 * Each message payload is parsed as a {@link BinlogEventProtobuf.Event}; messages whose payload cannot be
 * parsed are logged and skipped. When a non-empty set of tables is configured via {@link #requiredTables(Set)},
 * the table name is first read through {@link BinlogEventProtobuf.EventTableOnly} and only events of the
 * required tables are fully deserialized.
 * <p>
 * Instances are created with {@link Builder}.
 */
@ParametersAreNonnullByDefault
public class BinlogLogbrokerReader extends LogbrokerBatchReaderImpl<BinlogEventWithOffset> {
    private static final Logger logger = LoggerFactory.getLogger(BinlogLogbrokerReader.class);

    private static final long DEFAULT_ROWS_THRESHOLD = 10_000;
    private static final Duration DEFAULT_TIME_THRESHOLD = Duration.ofSeconds(5);

    /**
     * Tables whose events are kept; an empty set (the default) keeps events of every table.
     */
    private Set<String> requiredTables = emptySet();

    /**
     * Pool used to deserialize a batch in parallel; {@code null} means deserialize on the calling thread.
     */
    @Nullable
    private final ForkJoinPool forkJoinPool;
    private final long rowsThreshold;
    private final Duration timeThreshold;
    /**
     * Optional byte limit for a batch; {@code null} disables the byte check.
     */
    @Nullable
    private final Long bytesThreshold;

    private BinlogLogbrokerReader(Supplier<SyncConsumer> logbrokerConsumerSupplier, boolean logbrokerNoCommit,
                                  MetricRegistry metricRegistry, boolean needReadingOptimization,
                                  @Nullable ForkJoinPool forkJoinPool, long rowsThreshold, Duration timeThreshold,
                                  @Nullable Long bytesThreshold) {
        super(logbrokerConsumerSupplier, logbrokerNoCommit, metricRegistry, needReadingOptimization);

        this.forkJoinPool = forkJoinPool;
        this.rowsThreshold = rowsThreshold;
        this.timeThreshold = timeThreshold;
        this.bytesThreshold = bytesThreshold;
    }

    /**
     * Deserializes every message of the batch into a binlog event.
     * <p>
     * Messages of tables outside {@link #requiredTables(Set)} (when it is non-empty) and messages whose
     * payload is not valid protobuf are dropped. If a {@link ForkJoinPool} was configured, messages are
     * processed with a parallel stream submitted to that pool; otherwise sequentially on the calling thread.
     *
     * @param messageBatch batch read from Logbroker
     * @return the parsed events, without {@code null} entries
     * @throws InterruptedRuntimeException if the calling thread is interrupted while waiting for the pool
     * @throws LogbrokerReaderException    if the parallel deserialization task fails
     */
    @Override
    protected List<BinlogEventWithOffset> batchDeserialize(MessageBatch messageBatch) {
        try (TraceProfile profile = Trace.current().profile("binlog_logbroker_reader.deserialize")) {
            int partition = messageBatch.getPartition();
            // Snapshot the filter once so every (possibly parallel) worker uses the same set for this batch.
            Set<String> tables = requiredTables;
            Function<Stream<MessageData>, List<BinlogEventWithOffset>> deserializer =
                    s -> s.map(DecompressedMessageData::new)
                            .filter(decompressedData -> isRequired(decompressedData, tables))
                            .map(decompressedData -> decompressedData.toBinlogEventWithOffset(partition))
                            .filter(Objects::nonNull)
                            .collect(Collectors.toList());

            if (forkJoinPool == null) {
                return deserializer.apply(messageBatch.getMessageData().stream());
            }
            try {
                // NOTE: relies on the (implementation-specific) behavior that a parallel stream started from a
                // task running in a ForkJoinPool executes in that pool rather than in the common pool.
                return forkJoinPool.submit(
                        () -> deserializer.apply(messageBatch.getMessageData().parallelStream())).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new InterruptedRuntimeException(e);
            } catch (ExecutionException e) {
                throw new LogbrokerReaderException(e);
            }
        }
    }

    /**
     * Decides whether a message passes the table filter.
     *
     * @param data   message with its decompressed payload
     * @param tables required tables; empty means every table is accepted
     * @return {@code true} if the filter is empty, or the message's table is known and contained in it
     */
    private static boolean isRequired(DecompressedMessageData data, Set<String> tables) {
        if (tables.isEmpty()) {
            return true;
        }
        String binlogTable = data.getTable();
        return binlogTable != null && tables.contains(binlogTable);
    }

    /**
     * Parses a decompressed payload into a binlog event annotated with its Logbroker position.
     *
     * @param messageData      source message; supplies the offset and the sequence number
     * @param decompressedData decompressed payload of {@code messageData}
     * @param partition        partition the message was read from
     * @return the parsed event
     * @throws InvalidProtocolBufferException if the payload is not a valid {@link BinlogEventProtobuf.Event}
     */
    public static BinlogEventWithOffset parse(MessageData messageData, byte[] decompressedData, int partition)
            throws InvalidProtocolBufferException {
        return new BinlogEventWithOffset(fromProtobuf(parseFrom(decompressedData)), messageData.getOffset(), partition,
                messageData.getMessageMeta().getSeqNo());
    }

    /**
     * Counts the total number of rows carried by the given events.
     *
     * @param binlogs parsed events
     * @return sum of row counts over all events
     */
    @Override
    protected int count(List<BinlogEventWithOffset> binlogs) {
        return binlogs.stream().mapToInt(binlog -> binlog.getEvent().getRows().size()).sum();
    }

    /**
     * Builds the predicate that returns {@code true} while all configured thresholds are still below their
     * limits: rows, elapsed time and, if configured, bytes.
     *
     * @return batching threshold predicate over (rows, elapsed time, bytes)
     */
    @Override
    protected BatchingThresholdPredicate<Long, Duration, Long> batchingThresholdPredicate() {
        if (bytesThreshold == null) {
            return (rows, time, bytes) -> rows < rowsThreshold && time.compareTo(timeThreshold) < 0;
        }

        // Unbox once instead of on every predicate evaluation.
        long bytesLimit = bytesThreshold;
        return (rows, time, bytes) -> rows < rowsThreshold && time.compareTo(timeThreshold) < 0
                && bytes < bytesLimit;
    }

    /**
     * Restricts deserialization to events of the given tables. An empty set (the default) keeps every table.
     * The set is stored by reference, not copied.
     *
     * @param requiredTables table names whose events must be kept
     * @return this reader, for chaining
     * @throws NullPointerException if {@code requiredTables} is {@code null}
     */
    public BinlogLogbrokerReader requiredTables(Set<String> requiredTables) {
        this.requiredTables = Objects.requireNonNull(requiredTables, "requiredTables");

        return this;
    }

    /**
     * Builder for {@link BinlogLogbrokerReader}. Only the consumer supplier is mandatory.
     * <p>
     * Defaults: commit enabled ({@code logbrokerNoCommit = false}), no reading optimization, sequential
     * deserialization (no pool), {@value #DEFAULT_ROWS_THRESHOLD_DOC} rows, 5 seconds, no byte limit.
     */
    public static class Builder {
        private static final long DEFAULT_ROWS_THRESHOLD_DOC = DEFAULT_ROWS_THRESHOLD;

        private Supplier<SyncConsumer> logbrokerConsumerSupplier;
        private boolean logbrokerNoCommit = false;
        private MetricRegistry metricRegistry;
        private boolean needReadingOptimization = false;
        private ForkJoinPool forkJoinPool;
        private long rowsThreshold = DEFAULT_ROWS_THRESHOLD;
        private Duration timeThreshold = DEFAULT_TIME_THRESHOLD;
        private Long bytesThreshold;

        public Builder withLogbrokerConsumerSupplier(Supplier<SyncConsumer> logbrokerConsumerSupplier) {
            this.logbrokerConsumerSupplier = logbrokerConsumerSupplier;
            return this;
        }

        public Builder withLogbrokerNoCommit(boolean logbrokerNoCommit) {
            this.logbrokerNoCommit = logbrokerNoCommit;
            return this;
        }

        public Builder withMetricRegistry(MetricRegistry metricRegistry) {
            this.metricRegistry = metricRegistry;
            return this;
        }

        public Builder withNeedReadingOptimization(boolean needReadingOptimization) {
            this.needReadingOptimization = needReadingOptimization;
            return this;
        }

        /**
         * Enables parallel deserialization of each batch inside the given pool.
         */
        public Builder withForkJoinPool(ForkJoinPool forkJoinPool) {
            this.forkJoinPool = forkJoinPool;
            return this;
        }

        public Builder withRowsThreshold(long rowsThreshold) {
            this.rowsThreshold = rowsThreshold;
            return this;
        }

        public Builder withTimeThreshold(Duration timeThreshold) {
            this.timeThreshold = timeThreshold;
            return this;
        }

        /**
         * Enables the byte limit for a batch; without it only rows and time are checked.
         */
        public Builder withBytesThreshold(long bytesThreshold) {
            this.bytesThreshold = bytesThreshold;
            return this;
        }

        /**
         * Creates the reader.
         *
         * @return a new reader configured from this builder
         * @throws IllegalArgumentException if the consumer supplier or the time threshold is {@code null}
         */
        public BinlogLogbrokerReader build() {
            checkArgument(logbrokerConsumerSupplier != null, "Null consumer supplier!");
            // Fail fast: a null time threshold would otherwise surface as an NPE inside the batching predicate.
            checkArgument(timeThreshold != null, "Null time threshold!");
            return new BinlogLogbrokerReader(logbrokerConsumerSupplier, logbrokerNoCommit, metricRegistry,
                    needReadingOptimization, forkJoinPool, rowsThreshold, timeThreshold, bytesThreshold);
        }
    }

    /**
     * A Logbroker message paired with the payload returned by {@link MessageData#getDecompressedData()}, so the
     * payload is fetched once even when it is parsed twice (table-only filter, then full event).
     * <p>
     * Static: it needs no state of the enclosing reader, so it must not hold a reference to it.
     */
    private static final class DecompressedMessageData {
        private final MessageData messageData;
        private final byte[] decompressedData;

        DecompressedMessageData(MessageData messageData) {
            this.messageData = messageData;
            this.decompressedData = messageData.getDecompressedData();
        }

        /**
         * Reads only the table name from the payload.
         *
         * @return the table name, or {@code null} if the payload is not valid protobuf (the error is logged)
         */
        @Nullable
        String getTable() {
            try {
                return BinlogEventProtobuf.EventTableOnly.parseFrom(decompressedData)
                        .getTable();
            } catch (InvalidProtocolBufferException ex) {
                logger.error("Failed to batchProcessing protobuf", ex);
                return null;
            }
        }

        /**
         * Fully parses the payload.
         *
         * @param partition partition the message was read from
         * @return the parsed event, or {@code null} if the payload is not valid protobuf (the error is logged)
         */
        @Nullable
        BinlogEventWithOffset toBinlogEventWithOffset(int partition) {
            try {
                return parse(messageData, decompressedData, partition);
            } catch (InvalidProtocolBufferException ex) {
                logger.error("Failed to batchProcessing protobuf", ex);
                return null;
            }
        }
    }
}
