package ru.yandex.direct.logicprocessor.components;

import java.time.Duration;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import javax.annotation.ParametersAreNonnullByDefault;

import ru.yandex.direct.binlogbroker.logbroker_utils.reader.impl.LogbrokerBatchReaderImpl;
import ru.yandex.direct.ess.common.converter.LogicObjectWithSystemInfoConverter;
import ru.yandex.direct.ess.common.models.BaseLogicObject;
import ru.yandex.direct.ess.common.models.LogicObjectListWithInfo;
import ru.yandex.kikimr.persqueue.consumer.SyncConsumer;
import ru.yandex.kikimr.persqueue.consumer.transport.message.inbound.data.MessageBatch;
import ru.yandex.monlib.metrics.registry.MetricRegistry;

/**
 * Logbroker batch reader that turns each message's decompressed payload into a
 * {@link LogicObjectListWithInfo} using {@link LogicObjectWithSystemInfoConverter}.
 *
 * <p>Batching stops as soon as either the accumulated row count reaches {@code rowThreshold}
 * or the accumulated read time reaches {@code readTimeThreshold} (see {@link #batchingThreshold()}).
 *
 * @param <T> type of logic objects carried in the messages
 */
@ParametersAreNonnullByDefault
public class LogicObjectLogbrokerReader<T extends BaseLogicObject> extends LogbrokerBatchReaderImpl<LogicObjectListWithInfo<T>> {

    private final int rowThreshold;
    private final Duration readTimeThreshold;
    private final LogicObjectWithSystemInfoConverter<T> logicObjectWithSystemInfoConverter;

    /**
     * Creates a reader.
     *
     * @param logbrokerConsumerSupplier supplier of the underlying Logbroker consumer (passed to the base class)
     * @param logbrokerNoCommit         flag passed through to the base class
     * @param clazz                     logic object class used to build the JSON converter
     * @param rowThreshold              batching continues only while the row count is below this value
     * @param readTimeThreshold         batching continues only while the read time is below this value
     * @param metricRegistry            metric registry passed to the base class
     * @param needReadingOptimizations  flag passed through to the base class
     * @throws NullPointerException if {@code readTimeThreshold} is {@code null}
     */
    public LogicObjectLogbrokerReader(Supplier<SyncConsumer> logbrokerConsumerSupplier,
                                      boolean logbrokerNoCommit, Class<? extends BaseLogicObject> clazz,
                                      int rowThreshold,
                                      Duration readTimeThreshold, MetricRegistry metricRegistry,
                                      boolean needReadingOptimizations) {
        super(logbrokerConsumerSupplier, logbrokerNoCommit, metricRegistry, needReadingOptimizations);
        this.rowThreshold = rowThreshold;
        // Fail fast: a null threshold would otherwise only blow up later inside batchingThreshold().
        this.readTimeThreshold = Objects.requireNonNull(readTimeThreshold, "readTimeThreshold");
        this.logicObjectWithSystemInfoConverter = new LogicObjectWithSystemInfoConverter<>(clazz);
    }

    /**
     * Deserializes every message in the batch; entries whose logic-object list is {@code null}
     * are dropped.
     *
     * @param messageBatch batch received from Logbroker
     * @return deserialized entries, in the same order as the messages that produced them
     */
    @Override
    protected List<LogicObjectListWithInfo<T>> batchDeserialize(MessageBatch messageBatch) {
        // Collectors.toList() keeps encounter order even for a parallel stream.
        // NOTE(review): assumes the converter's fromJson is safe to call concurrently — TODO confirm.
        return messageBatch.getMessageData()
                .parallelStream()
                .map(messageData -> logicObjectWithSystemInfoConverter.fromJson(messageData.getDecompressedData()))
                .filter(logicObjectsWithSystemInfo -> logicObjectsWithSystemInfo.getLogicObjectsList() != null)
                .collect(Collectors.toList());
    }

    /**
     * Returns the predicate that keeps a batch open.
     *
     * <p>The predicate receives a row count and an elapsed duration (presumably accumulated by the
     * base class — verify against {@code LogbrokerBatchReaderImpl}). It returns {@code true} only
     * while both are strictly below the configured thresholds.
     *
     * @return predicate over (rows, time)
     */
    @Override
    protected BiPredicate<Long, Duration> batchingThreshold() {
        return (rows, time) -> rows < rowThreshold && time.compareTo(readTimeThreshold) < 0;
    }

    /**
     * Counts logic objects across all deserialized entries.
     *
     * @param logicObjects deserialized entries; each is expected to have a non-null logic-object list
     *                     (as guaranteed by {@link #batchDeserialize(MessageBatch)})
     * @return total number of logic objects
     */
    @Override
    protected int count(List<LogicObjectListWithInfo<T>> logicObjects) {
        return logicObjects.stream()
                .map(LogicObjectListWithInfo::getLogicObjectsList)
                .mapToInt(Collection::size)
                .sum();
    }
}
