package ru.yandex.travel.cpa.data_processing.flow.processors;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;

import ru.yandex.travel.cpa.data_processing.flow.logbroker.LogbrokerDataBatch;
import ru.yandex.travel.cpa.data_processing.flow.model.MessageDecoder;
import ru.yandex.travel.cpa.data_processing.flow.model.orders.OrderKey;
import ru.yandex.travel.cpa.data_processing.flow.model.snapshots.Snapshot;
import ru.yandex.travel.cpa.data_processing.flow.model.snapshots.SnapshotError;
import ru.yandex.travel.cpa.data_processing.flow.model.snapshots.SnapshotJsonDecoder;
import ru.yandex.travel.cpa.data_processing.flow.yt.SyncYtClient;

/**
 * Consumes Logbroker batches of order snapshots, deduplicates them against what is already stored in YT
 * and writes the resulting changes back to YT.
 *
 * <p>Snapshots are split into two priorities by the message source id (suffix {@value #LOW_PRIORITY_SOURCE_ID_SUFFIX}
 * means {@link OrderProcessingPriority#LOW}); the changed orders of each priority are pushed to a separate
 * order queue. The snapshot, order-queue and processed-snapshot writes of one priority group share a single
 * transaction obtained from the snapshots client.
 */
@Slf4j
public class SnapshotProcessor implements Processor {
    private static final String LOW_PRIORITY_SOURCE_ID_SUFFIX = "-low-priority";
    /** A snapshot carrying this hash is a control message asking the processor to stop. */
    private static final String BREAKER_HASH = "PLEASE_STOP";
    /** Error type recorded for two different snapshots of one order sharing the same updatedAt. */
    private static final String UNIQUE_ERROR = "key_unique_error";
    /** Set once a breaker snapshot has been decoded; exposed via {@link #isProcessingFinished()}. */
    private boolean processingFinished = false;
    private final MessageDecoder<Snapshot> snapshotDecoder = new SnapshotJsonDecoder();
    private final SyncYtClient<OrderKey, Snapshot> processedSnapshotsClient;
    private final SyncYtClient<OrderKey, Snapshot> snapshotsClient;
    private final SyncYtClient<OrderKey, SnapshotError> snapshotErrorsClient;
    private final SyncYtClient<OrderKey, OrderKey> orderQueueClient;
    private final SyncYtClient<OrderKey, OrderKey> slowOrderQueueClient;
    private final AtomicBoolean isClosed = new AtomicBoolean(false);

    private final Counter incomingNewSnapshotsCounter;
    private final Counter incomingOldSnapshotsCounter;
    private final Counter deduplicatedSnapshotsCounter;
    private final Timer readLagDistribution;
    private final Counter snapshotUniqueKeyErrorsCounter;

    /** Result of {@link #getDeduplicatedSnapshots}: everything that has to be written for one batch group. */
    @Value
    public static class DeduplicatedSnapshots {
        List<Snapshot> snapshotsToDelete;
        Map<OrderKey, List<Snapshot>> newSnapshots;
        List<Snapshot> processedSnapshots;
        List<SnapshotError> snapshotsNotUniqueKey;
        List<OrderKey> ordersToUpdate;
        int deduplicatedCount;
        int newCount;
        int oldCount;
    }

    /** Deduplicated snapshots of one order plus the uniqueness errors found while deduplicating. */
    @Value
    private static class DeduplicatedGroup {
        List<Snapshot> snapshots;
        List<SnapshotError> snapshotsNotUniqueKey;
    }

    /** Selects which order queue the changed orders of a group are sent to. */
    private enum OrderProcessingPriority {
        REGULAR,
        LOW
    }

    /** Identity of a single snapshot version: order (partner + partner order id), content hash and timestamp. */
    @Value
    public static class SnapshotKey {
        String partnerName;
        String partnerOrderId;
        String hash;
        long updatedAt;

        public static SnapshotKey fromSnapshot(Snapshot snapshot) {
            return new SnapshotKey(
                    snapshot.getPartnerName(),
                    snapshot.getPartnerOrderId(),
                    snapshot.getHash(),
                    snapshot.getUpdatedAt()
            );
        }
    }

    /** A {@link java.util.function.Function} variant that may throw a checked exception. */
    @FunctionalInterface
    public interface ThrowingFunction<T, R> {
        /**
         * Applies this function to the given argument.
         *
         * @param t the function argument
         * @return the function result
         */
        R apply(T t) throws Exception;
    }

    public SnapshotProcessor(
            SyncYtClient<OrderKey, Snapshot> processedSnapshotsClient,
            SyncYtClient<OrderKey, Snapshot> snapshotsClient,
            SyncYtClient<OrderKey, SnapshotError> snapshotErrorsClient,
            SyncYtClient<OrderKey, OrderKey> orderQueueClient,
            SyncYtClient<OrderKey, OrderKey> slowOrderQueueClient
    ) {
        this.processedSnapshotsClient = processedSnapshotsClient;
        this.snapshotsClient = snapshotsClient;
        this.snapshotErrorsClient = snapshotErrorsClient;
        this.orderQueueClient = orderQueueClient;
        this.slowOrderQueueClient = slowOrderQueueClient;

        incomingNewSnapshotsCounter = Counter
                .builder("cpa.flow.incomingSnapshotsCount")
                .tag("status", "new")
                .register(Metrics.globalRegistry);

        incomingOldSnapshotsCounter = Counter
                .builder("cpa.flow.incomingSnapshotsCount")
                .tag("status", "old")
                .register(Metrics.globalRegistry);

        deduplicatedSnapshotsCounter = Counter
                .builder("cpa.flow.deduplicatedSnapshotsCount")
                .register(Metrics.globalRegistry);

        readLagDistribution = Timer.builder("cpa.flow.snapshotsReadLagDistribution")
                .serviceLevelObjectives(
                        Duration.ofSeconds(1),
                        Duration.ofSeconds(5),
                        Duration.ofSeconds(30),
                        Duration.ofSeconds(60),
                        Duration.ofSeconds(300),
                        Duration.ofSeconds(600),
                        Duration.ofSeconds(1800),
                        Duration.ofSeconds(3600),
                        Duration.ofSeconds(7200),
                        Duration.ofSeconds(10800),
                        Duration.ofSeconds(14400),
                        Duration.ofSeconds(18000)
                )
                .register(Metrics.globalRegistry);

        snapshotUniqueKeyErrorsCounter = Counter
                .builder("cpa.flow.errorsCount")
                .tag("type", "snapshotUniqueKey")
                .register(Metrics.globalRegistry);
    }

    /**
     * Processes one batch, retrying the YT interaction on {@link ExecutionException} / {@link TimeoutException}
     * until it succeeds or {@link #close()} is called.
     *
     * @param snapshots the Logbroker batch to process
     * @return {@code true} if the batch was processed, {@code false} if the processor was closed first
     * @throws Exception decoding errors and any YT error other than the retried ones
     */
    public boolean process(LogbrokerDataBatch snapshots) throws Exception {
        Map<OrderProcessingPriority, List<Snapshot>> convertedSnapshots = null;
        while (!isClosed.get()) {
            try {
                if (convertedSnapshots == null) {
                    // Decode only once: a retry repeats the YT interaction, not the decoding, so the
                    // read-lag timer is not fed the same snapshots again on every retry.
                    convertedSnapshots = getConvertedSnapshots(snapshots);
                }
                tryProcess(convertedSnapshots);
                return true;
            } catch (ExecutionException | TimeoutException e) {
                // NOTE(review): retries immediately, without backoff.
                log.warn("YT interaction error", e);
            }
        }
        return false;
    }

    /** @return {@code true} once a snapshot with the breaker hash has been decoded */
    public boolean isProcessingFinished() {
        return processingFinished;
    }

    /** Stops the retry loop in {@link #process}; an in-flight attempt is not interrupted. */
    public void close() {
        isClosed.set(true);
    }

    /** Processes every non-empty priority group of an already decoded batch. */
    private void tryProcess(Map<OrderProcessingPriority, List<Snapshot>> convertedSnapshots) throws Exception {
        for (var snapshotGroup : convertedSnapshots.entrySet()) {
            if (snapshotGroup.getValue().isEmpty()) {
                // Nothing was decoded for this priority (e.g. only hash-less snapshots): skip the YT round trips.
                continue;
            }
            tryProcessGroup(snapshotGroup.getKey(), snapshotGroup.getValue());
        }
    }

    /**
     * Deduplicates one priority group and, if anything changed, writes the changes to YT.
     *
     * @param priority  selects the order queue the changed orders are sent to
     * @param snapshots decoded snapshots of this priority
     */
    private void tryProcessGroup(OrderProcessingPriority priority, List<Snapshot> snapshots) throws Exception {
        var groupedSnapshots = getGroupedSnapshots(snapshots);
        var processedSnapshots = getProcessedSnapshots(groupedSnapshots);
        var deduplicatedSnapshots = getDeduplicatedSnapshots(groupedSnapshots, processedSnapshots, snapshotsClient::select);

        if (!hasUpdates(deduplicatedSnapshots)) {
            calculateMetrics(deduplicatedSnapshots);
            return;
        }

        var processedSnapshotsToSend = deduplicatedSnapshots.getProcessedSnapshots();

        try (var transaction = snapshotsClient.getTransaction()) {
            snapshotsClient.replace(
                    deduplicatedSnapshots.getSnapshotsToDelete(),
                    getSnapshotsFlat(deduplicatedSnapshots.getNewSnapshots()),
                    transaction
            );
            snapshotsClient.send(getSnapshotsFlat(deduplicatedSnapshots.getNewSnapshots()), transaction);
            switch (priority) {
                case LOW:
                    slowOrderQueueClient.send(deduplicatedSnapshots.getOrdersToUpdate(), transaction);
                    break;
                case REGULAR:
                    orderQueueClient.send(deduplicatedSnapshots.getOrdersToUpdate(), transaction);
                    break;
                default:
                    throw new IllegalStateException("Unexpected value: " + priority);
            }
            processedSnapshotsToSend = processedSnapshotsToSend == null ? List.of() : processedSnapshotsToSend;
            processedSnapshotsClient.send(processedSnapshotsToSend, transaction);
            // NOTE(review): unlike the writes above, error rows are sent without the transaction, so a
            // retried attempt may send them again — confirm this is intended.
            snapshotErrorsClient.send(deduplicatedSnapshots.getSnapshotsNotUniqueKey());
            snapshotsClient.commitTransaction(transaction);
        }

        calculateMetrics(deduplicatedSnapshots);
    }

    /**
     * Groups snapshots by order and sorts each group chronologically (by updatedAt, then by hash).
     *
     * @param snapshots snapshots in arbitrary order
     * @return a new mutable map from order key to a new, sorted, mutable list of that order's snapshots
     */
    public static Map<OrderKey, List<Snapshot>> getGroupedSnapshots(List<Snapshot> snapshots) {
        var groupedSnapshots = new HashMap<OrderKey, List<Snapshot>>();
        for (var snapshot : snapshots) {
            var key = new OrderKey(snapshot);
            var groupSnapshots = groupedSnapshots.computeIfAbsent(key, k -> new ArrayList<>());
            groupSnapshots.add(snapshot);
        }
        for (var group : groupedSnapshots.values()) {
            sortGroup(group);
        }
        return groupedSnapshots;
    }

    /** Sorts snapshots in place by updatedAt ascending, breaking ties by hash. */
    private static void sortGroup(List<Snapshot> group) {
        group.sort((a, b) -> {
            // Long.compare instead of subtracting and casting to int, which can overflow and flip the sign.
            int diff = Long.compare(a.getUpdatedAt(), b.getUpdatedAt());
            if (diff == 0) {
                diff = a.getHash().compareTo(b.getHash());
            }
            return diff;
        });
    }

    /**
     * Decodes the batch into snapshots grouped by priority, recording read lag on the way.
     *
     * <p>Snapshots without a hash are skipped. A snapshot with the breaker hash sets
     * {@link #processingFinished} and stops decoding the current message.
     */
    private Map<OrderProcessingPriority, List<Snapshot>> getConvertedSnapshots(LogbrokerDataBatch snapshots) throws java.io.IOException {
        long batchProcessingStartTime = System.currentTimeMillis() / 1000;
        var convertedSnapshots = new HashMap<OrderProcessingPriority, List<Snapshot>>();
        long maxLag = 0;
        for (var message : snapshots.getMessages()) {
            OrderProcessingPriority priority;
            if (message.getSourceId().endsWith(LOW_PRIORITY_SOURCE_ID_SUFFIX)) {
                priority = OrderProcessingPriority.LOW;
            } else {
                priority = OrderProcessingPriority.REGULAR;
            }
            List<Snapshot> snapshotGroup = convertedSnapshots.computeIfAbsent(priority, (key) -> new ArrayList<>());
            for (var convertedSnapshot : snapshotDecoder.decode(message.getBytes())) {
                // updatedAt is compared with a seconds-based "now", so the lag is in seconds.
                var updatedAt = convertedSnapshot.getUpdatedAt();
                var snapshotLag = batchProcessingStartTime - updatedAt;
                readLagDistribution.record(Duration.ofSeconds(snapshotLag));

                if (snapshotLag > maxLag) {
                    maxLag = snapshotLag;
                }

                var snapshotHash = convertedSnapshot.getHash();
                if (snapshotHash == null) {
                    // Logbroker may read snapshots beyond the requested window at reading start. Some of those
                    // snapshots have no hash because the collectors did not compute hashes back then.
                    // TODO: add monitoring or something to be sure we don't receive such snapshots after reading start
                    continue;
                }

                if (snapshotHash.equals(BREAKER_HASH)) {
                    log.info("Got breaker snapshot");
                    processingFinished = true;
                    // NOTE(review): this leaves only the current message; later messages of the batch are still
                    // decoded and processed — confirm this is intended.
                    break;
                }
                snapshotGroup.add(convertedSnapshot);
            }
        }
        log.debug("batch lag max {}", Duration.ofSeconds(maxLag));
        return convertedSnapshots;
    }

    /**
     * Looks up the last processed snapshot of every order present in {@code snapshots}.
     *
     * @return order key to its processed snapshot; orders without one are absent
     */
    private HashMap<OrderKey, Snapshot> getProcessedSnapshots(
            Map<OrderKey, List<Snapshot>> snapshots
    ) throws Exception {
        // One representative snapshot per order is passed to the lookup.
        var snapshotsToLookup = new ArrayList<Snapshot>();
        for (var group : snapshots.values()) {
            snapshotsToLookup.add(group.get(0));
        }

        var processedSnapshots = new HashMap<OrderKey, Snapshot>();
        for (var snapshot : processedSnapshotsClient.lookup(snapshotsToLookup)) {
            processedSnapshots.put(new OrderKey(snapshot), snapshot);
        }
        return processedSnapshots;
    }

    /**
     * Computes what must be written for a group of incoming snapshots.
     *
     * <p>Orders whose oldest incoming snapshot is not older than the last processed one take the fast path:
     * they are deduplicated against the processed snapshot only. The remaining orders contain out-of-order
     * snapshots; their stored snapshots are read via {@code snapshotReader}, merged with the incoming ones and
     * deduplicated again, which may produce both additions and deletions.
     *
     * @param snapshots          incoming snapshots grouped by order, each group sorted chronologically
     * @param processedSnapshots last processed snapshot per order (absent when none)
     * @param snapshotReader     reads stored snapshots of the given orders
     * @return snapshots to add/delete, processed snapshots to update, changed orders, errors and counters
     */
    public static DeduplicatedSnapshots getDeduplicatedSnapshots(
            Map<OrderKey, List<Snapshot>> snapshots,
            Map<OrderKey, Snapshot> processedSnapshots,
            ThrowingFunction<List<OrderKey>, List<Snapshot>> snapshotReader
    ) throws Exception {
        int newSnapshotCount = 0;
        int oldSnapshotCount = 0;
        int deduplicatedSnapshotCount = 0;

        var snapshotsNotUniqueKey = new ArrayList<SnapshotError>();
        var processedSnapshotsToUpdate = new ArrayList<Snapshot>();
        var ordersToUpdate = new ArrayList<OrderKey>();

        var newSnapshots = new HashMap<OrderKey, List<Snapshot>>();
        var snapshotsToSelect = new ArrayList<OrderKey>();
        for (var item : snapshots.entrySet()) {
            var key = item.getKey();
            var group = item.getValue();

            var previousSnapshot = processedSnapshots.get(key);

            long processedSnapshotUpdatedAt = 0;
            if (previousSnapshot != null) {
                processedSnapshotUpdatedAt = previousSnapshot.getUpdatedAt();
            }

            if (group.get(0).getUpdatedAt() >= processedSnapshotUpdatedAt) {
                // only new snapshots
                newSnapshotCount += group.size();
                var deduplicatedGroup = getDeduplicatedGroup(group, previousSnapshot);
                var deduplicatedGroupSnapshots = deduplicatedGroup.getSnapshots();
                if (!deduplicatedGroupSnapshots.isEmpty()) {
                    newSnapshots.put(key, deduplicatedGroupSnapshots);
                    processedSnapshotsToUpdate.add(deduplicatedGroupSnapshots.get(deduplicatedGroupSnapshots.size() - 1));
                    ordersToUpdate.add(key);
                    // NOTE(review): counts the whole incoming group, while the merge path below counts only
                    // the snapshots actually added — confirm which meaning the metric should have.
                    deduplicatedSnapshotCount += group.size();
                }
                snapshotsNotUniqueKey.addAll(deduplicatedGroup.getSnapshotsNotUniqueKey());
            } else {
                // some snapshots are old
                oldSnapshotCount += group.size();
                snapshotsToSelect.add(key);
            }
        }

        var existingSnapshots = getGroupedSnapshots(snapshotReader.apply(snapshotsToSelect));

        var snapshotsToDelete = new ArrayList<Snapshot>();

        // NOTE(review): only orders for which stored snapshots were returned are merged; incoming old snapshots
        // of an order with nothing stored are not written — confirm this cannot happen or is intended.
        for (var item : existingSnapshots.entrySet()) {
            var key = item.getKey();
            var existingGroup = getExistingGroup(key, existingSnapshots);

            var group = new ArrayList<>(existingGroup);
            group.addAll(Optional.ofNullable(snapshots.get(key)).orElse(new ArrayList<>()));
            sortGroup(group);
            var deduplicatedGroup = getDeduplicatedGroup(group, null);
            if (deduplicatedGroup.getSnapshots().isEmpty()) {
                continue;
            }
            var allSnapshots = new HashMap<SnapshotKey, Snapshot>();
            for (var snapshot : group) {
                allSnapshots.put(SnapshotKey.fromSnapshot(snapshot), snapshot);
            }
            var deduplicatedSnapshotKeys = deduplicatedGroup.getSnapshots().stream()
                    .map(SnapshotKey::fromSnapshot)
                    .collect(Collectors.toSet());
            var existingSnapshotKeys = existingGroup.stream()
                    .map(SnapshotKey::fromSnapshot)
                    .collect(Collectors.toSet());

            boolean orderIsChanged = false;
            // Versions that survive deduplication but are not stored yet.
            var snapshotKeysToAdd = new HashSet<>(deduplicatedSnapshotKeys);
            snapshotKeysToAdd.removeAll(existingSnapshotKeys);
            if (!snapshotKeysToAdd.isEmpty()) {
                var newGroup = snapshotKeysToAdd.stream()
                        .map(allSnapshots::get)
                        .collect(Collectors.toCollection(ArrayList::new));
                // HashSet iteration order is unspecified: sort so the list is chronological and its last
                // element really is the newest snapshot being added.
                sortGroup(newGroup);
                newSnapshots.put(key, newGroup);
                var lastSnapshotToAdd = newGroup.get(newGroup.size() - 1);

                // Advance the processed marker only if the newest added snapshot is newer than it, or has the
                // same timestamp but different content.
                var previousSnapshot = processedSnapshots.get(key);
                if (previousSnapshot == null
                        || lastSnapshotToAdd.getUpdatedAt() > previousSnapshot.getUpdatedAt()
                        || (lastSnapshotToAdd.getUpdatedAt() == previousSnapshot.getUpdatedAt()
                                && !lastSnapshotToAdd.getHash().equals(previousSnapshot.getHash()))) {
                    processedSnapshotsToUpdate.add(lastSnapshotToAdd);
                }

                deduplicatedSnapshotCount += snapshotKeysToAdd.size();
                orderIsChanged = true;
            }

            // Stored versions that became duplicates after merging with the incoming snapshots.
            var snapshotKeysToDelete = new HashSet<>(existingSnapshotKeys);
            snapshotKeysToDelete.removeAll(deduplicatedSnapshotKeys);
            if (!snapshotKeysToDelete.isEmpty()) {
                snapshotsToDelete.addAll(snapshotKeysToDelete.stream().map(allSnapshots::get).collect(Collectors.toList()));
                orderIsChanged = true;
            }
            if (orderIsChanged) {
                ordersToUpdate.add(key);
            }
        }

        log.info(
                "snapshots: {} new, {} old, {} deduplicated, {} not unique",
                newSnapshotCount,
                oldSnapshotCount,
                deduplicatedSnapshotCount,
                snapshotsNotUniqueKey.size()
        );

        return new DeduplicatedSnapshots(
                snapshotsToDelete,
                newSnapshots,
                processedSnapshotsToUpdate,
                snapshotsNotUniqueKey,
                ordersToUpdate,
                deduplicatedSnapshotCount,
                newSnapshotCount,
                oldSnapshotCount
        );
    }

    /** @return a mutable copy of the stored snapshots of {@code key}, or an empty list when there are none */
    private static List<Snapshot> getExistingGroup(OrderKey key, Map<OrderKey, List<Snapshot>> existingSnapshots) {
        var existingGroup = existingSnapshots.get(key);
        if (existingGroup == null) {
            existingGroup = new ArrayList<>();
        }
        return new ArrayList<>(existingGroup);
    }

    /**
     * Drops consecutive snapshots whose hash equals the previously kept one.
     *
     * @param group     snapshots of one order, sorted chronologically
     * @param compareTo snapshot the first element is compared with, or {@code null} to keep the first element
     * @return kept snapshots, plus an error for each kept snapshot whose updatedAt equals that of the
     *         previously kept one
     */
    private static DeduplicatedGroup getDeduplicatedGroup(List<Snapshot> group, Snapshot compareTo) {
        var deduplicatedGroup = new ArrayList<Snapshot>();
        var snapshotsNotUniqueKey = new HashSet<SnapshotError>();
        var previousSnapshot = compareTo;
        for (var snapshot : group) {
            if (snapshotsDiffers(previousSnapshot, snapshot)) {
                deduplicatedGroup.add(snapshot);

                if (updatedAtEquals(previousSnapshot, snapshot)) {
                    snapshotsNotUniqueKey.add(new SnapshotError(snapshot, UNIQUE_ERROR));
                }

                previousSnapshot = snapshot;
            }
        }
        return new DeduplicatedGroup(deduplicatedGroup, new ArrayList<>(snapshotsNotUniqueKey));
    }

    /** @return {@code true} if there is no previous snapshot or the hashes differ */
    private static boolean snapshotsDiffers(Snapshot previous, Snapshot next) {
        if (previous == null) {
            return true;
        }
        return !previous.getHash().equals(next.getHash());
    }

    /** @return {@code true} if both snapshots exist, share updatedAt and have different hashes */
    private static boolean updatedAtEquals(Snapshot previous, Snapshot next) {
        if (previous == null) {
            return false;
        }
        return previous.getUpdatedAt() == next.getUpdatedAt() && !previous.getHash().equals(next.getHash());
    }

    /** @return {@code true} if there is anything to add, delete or report */
    private boolean hasUpdates(DeduplicatedSnapshots snapshots) {
        return !(snapshots.getNewSnapshots().isEmpty() &&
                snapshots.getSnapshotsToDelete().isEmpty() &&
                snapshots.getSnapshotsNotUniqueKey().isEmpty());
    }

    /** Adds the counts of one processed group to the Micrometer counters. */
    private void calculateMetrics(DeduplicatedSnapshots deduplicatedSnapshots) {
        deduplicatedSnapshotsCounter.increment(deduplicatedSnapshots.getDeduplicatedCount());
        incomingNewSnapshotsCounter.increment(deduplicatedSnapshots.getNewCount());
        incomingOldSnapshotsCounter.increment(deduplicatedSnapshots.getOldCount());
        snapshotUniqueKeyErrorsCounter.increment(deduplicatedSnapshots.getSnapshotsNotUniqueKey().size());
    }

    /** @return all snapshots of all orders in one new list */
    private List<Snapshot> getSnapshotsFlat(Map<OrderKey, List<Snapshot>> snapshots) {
        return snapshots
                .values()
                .stream()
                .flatMap(List::stream)
                .collect(Collectors.toList());
    }
}
