package ru.yandex.infra.controller.yp;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import com.google.common.collect.Iterables;
import com.google.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Try;
import ru.yandex.infra.controller.dto.SchemaMeta;
import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;

import static java.util.Collections.emptyMap;

/**
 * In-memory cache of the YP objects matched by a {@link Selector}, refreshed on every call to
 * {@link #selectObjects(Optional)}.
 *
 * <p>Refresh strategy:
 * <ul>
 *     <li>If watches are disabled in the settings, every refresh is a full select.</li>
 *     <li>Otherwise the first refresh, and any refresh after {@code settings.getFullReloadInterval()} has elapsed
 *     (a zero interval disables periodic full reloads), is a full select. Other refreshes watch for changes since
 *     the cached snapshot's YP timestamp and re-fetch only the changed objects, in sequential batches of
 *     {@code settings.getGetObjectsBatchSize()}.</li>
 *     <li>If an incremental refresh fails (including by exceeding {@code settings.getWatchesTimeout()}), the cached
 *     snapshot is dropped and a full select is performed instead.</li>
 * </ul>
 *
 * <p>Timings, error counts and progress are exported as gauges through the {@link GaugeRegistry} passed to the
 * constructor.
 *
 * <p>Not designed for concurrent refreshes. The snapshot is swapped with compare-and-set, and a refresh that finds
 * the snapshot already replaced completes exceptionally rather than overwriting the newer snapshot.
 *
 * @param <Meta>   object meta type
 * @param <Spec>   object spec message type
 * @param <Status> object status message type
 */
public class YpObjectsCache<Meta extends SchemaMeta, Spec extends Message, Status extends Message> {
    private static final Logger LOG = LoggerFactory.getLogger(YpObjectsCache.class);

    static final String METRIC_UPDATED_OBJECTS_COUNT = "updated_objects_count";
    static final String METRIC_CURRENT_BATCH_OFFSET = "current_batch_offset";
    static final String METRIC_FULL_RELOAD_COUNT = "full_reload_count";
    static final String METRIC_WATCH_OBJECTS_TIME = "watch_objects_time_ms";
    static final String METRIC_WATCH_OBJECTS_ERRORS_COUNT = "watch_objects_errors";
    static final String METRIC_GET_OBJECTS_TIME = "get_objects_time_ms";
    static final String METRIC_GET_OBJECTS_ERRORS_COUNT = "get_objects_errors";
    static final String METRIC_YP_OBJECTS_LOAD_TIME = "yp_objects_load_time_ms";
    static final String METRIC_YP_OBJECTS_LOAD_ERRORS_COUNT = "yp_objects_load_errors";
    static final String METRIC_YP_OBJECTS_COUNT = "yp_objects_count";

    private final AtomicLong metricFullReloadCount = new AtomicLong();
    private final AtomicLong metricUpdatedObjectsCount = new AtomicLong();
    private final AtomicLong metricYpObjectLoadErrorsCount = new AtomicLong();
    private final AtomicLong metricWatchObjectsErrorsCount = new AtomicLong();
    private final AtomicLong metricGetObjectsErrorsCount = new AtomicLong();
    // Gauge values; null until the corresponding operation has run at least once.
    private volatile Integer metricCurrentBatchOffset;
    private volatile Long metricWatchObjectsTimeMilliseconds;
    private volatile Long metricGetObjectsTimeMilliseconds;
    private volatile Long metricYpObjectsLoadTimeMilliseconds;
    private volatile Integer metricYpObjectsCount;

    private final YpObjectTransactionalRepository<Meta, Spec, Status> ypRepository;
    // Last installed snapshot; null before the first successful full select or after a failed incremental refresh.
    private final AtomicReference<SelectedObjects<Meta, Spec, Status>> currentSnapshot = new AtomicReference<>();
    private final YpObjectSettings settings;
    private final Selector selector;
    private final Selector selectorWithLabels;
    // System.currentTimeMillis() of the last successful full select. It is written from a future-completion
    // callback and read at the start of the next refresh, possibly on a different thread, hence volatile.
    private volatile long lastFullReloadTimestamp;

    /**
     * Creates an empty cache and registers its gauges in {@code registry}.
     *
     * @param ypRepository repository used to select, watch and fetch objects
     * @param settings     watch, full-reload and batching settings
     * @param registry     registry the cache's gauges are added to
     * @param selector     selector used for full selects; {@code selector.withLabels()} is used when re-fetching
     *                     changed objects
     */
    public YpObjectsCache(YpObjectTransactionalRepository<Meta, Spec, Status> ypRepository,
                          YpObjectSettings settings,
                          GaugeRegistry registry,
                          Selector selector) {
        this.ypRepository = ypRepository;
        this.settings = settings;
        this.selector = selector;
        selectorWithLabels = selector.withLabels();

        registry.add(METRIC_UPDATED_OBJECTS_COUNT, new GolovanableGauge<>(metricUpdatedObjectsCount::get, "dxxm"));
        registry.add(METRIC_CURRENT_BATCH_OFFSET, new GolovanableGauge<>(() -> metricCurrentBatchOffset, "axxx"));
        registry.add(METRIC_FULL_RELOAD_COUNT, new GolovanableGauge<>(metricFullReloadCount::get, "dmmm"));
        registry.add(METRIC_WATCH_OBJECTS_TIME, new GolovanableGauge<>(() -> metricWatchObjectsTimeMilliseconds, "axxx"));
        registry.add(METRIC_WATCH_OBJECTS_ERRORS_COUNT, new GolovanableGauge<>(metricWatchObjectsErrorsCount::get, "dmmm"));
        registry.add(METRIC_GET_OBJECTS_TIME, new GolovanableGauge<>(() -> metricGetObjectsTimeMilliseconds, "axxx"));
        registry.add(METRIC_GET_OBJECTS_ERRORS_COUNT, new GolovanableGauge<>(metricGetObjectsErrorsCount::get, "dmmm"));
        registry.add(METRIC_YP_OBJECTS_LOAD_TIME, new GolovanableGauge<>(() -> metricYpObjectsLoadTimeMilliseconds, "axxx"));
        registry.add(METRIC_YP_OBJECTS_LOAD_ERRORS_COUNT, new GolovanableGauge<>(metricYpObjectLoadErrorsCount::get, "dmmm"));
        registry.add(METRIC_YP_OBJECTS_COUNT, new GolovanableGauge<>(() -> metricYpObjectsCount, "axxx"));
    }

    /**
     * Selects all objects matched by the selector and installs them as the new snapshot.
     *
     * <p>On success, also records the time of this full reload. When watches are enabled, this also bumps the
     * full-reload counter and logs progress.
     *
     * @param snapshot  snapshot expected to be current (may be null); installing fails if it was replaced meanwhile
     * @param timestamp YP timestamp to select at; when empty, no timestamp is passed to the repository
     * @return future with the selected objects
     */
    private CompletableFuture<Map<String, Try<YpObject<Meta, Spec, Status>>>> getInitialSnapshot(
            SelectedObjects<Meta, Spec, Status> snapshot, Optional<Long> timestamp) {
        if (settings.isWatchesEnabled()) {
            metricFullReloadCount.incrementAndGet();
            LOG.info("[{}] Loading initial snapshot from YP...", ypRepository);
        }
        long startTimeMillis = System.currentTimeMillis();
        var selectRequest = timestamp.isPresent() ? ypRepository.selectObjects(selector, emptyMap(), timestamp.get()) :
                ypRepository.selectObjects(selector, emptyMap());
        return selectRequest.thenApply(selectedObjects -> {
            tryUpdateCurrentSnapshot(snapshot, selectedObjects);
            var now = System.currentTimeMillis();
            lastFullReloadTimestamp = now;
            if (settings.isWatchesEnabled()) {
                LOG.info("[{}] Loaded {} objects in initial snapshot with YP timestamp {} in {} ms",
                        ypRepository, selectedObjects.getObjects().size(), selectedObjects.getTimestamp(),
                        now - startTimeMillis);
            }
            return selectedObjects.getObjects();
        });
    }

    /**
     * Refreshes the cache and returns the up-to-date objects.
     *
     * <p>Not expected to be called concurrently. If it is, the snapshot is still replaced atomically
     * (compare-and-set), but the call that loses the race completes exceptionally.
     *
     * <p>Records the total load time and, on failure, increments the load-error counter.
     *
     * @param timestamp YP timestamp used for a full select. NOTE(review): it is not used when the refresh is
     *                  incremental (watch-based); confirm this is intended.
     * @return future with the cached objects
     */
    public CompletableFuture<Map<String, Try<YpObject<Meta, Spec, Status>>>> selectObjects(Optional<Long> timestamp) {
        long startTimeMillis = System.currentTimeMillis();
        return reloadObjects(timestamp)
                .whenComplete((objects, error) -> {
                    metricYpObjectsLoadTimeMilliseconds = System.currentTimeMillis() - startTimeMillis;
                    if (error != null) {
                        metricYpObjectLoadErrorsCount.incrementAndGet();
                    }
                });
    }

    /**
     * Tells whether a periodic full reload is due.
     *
     * @return true if a non-zero full reload interval is configured and more than that interval has passed since
     *         the last successful full select
     */
    private boolean isFullReloadTimeoutExpired() {
        return !settings.getFullReloadInterval().isZero() &&
                (System.currentTimeMillis() - lastFullReloadTimestamp) > settings.getFullReloadInterval().toMillis();
    }

    /**
     * Chooses between a full select and an incremental (watch-based) refresh, and performs it.
     *
     * <p>An incremental refresh first watches for changes since the snapshot's timestamp, bounded by the watch
     * timeout. Changed objects are then re-fetched at the watch's timestamp and merged over the unchanged part of
     * the previous snapshot. Objects that are no longer returned by YP disappear from the result. Any failure drops
     * the snapshot and falls back to a full select.
     *
     * @param timestamp YP timestamp forwarded to a full select; not used by the incremental path
     * @return future with the refreshed objects
     */
    private CompletableFuture<Map<String, Try<YpObject<Meta, Spec, Status>>>> reloadObjects(Optional<Long> timestamp) {
        SelectedObjects<Meta, Spec, Status> snapshot = currentSnapshot.get();
        if (snapshot == null || !settings.isWatchesEnabled() || isFullReloadTimeoutExpired()) {
            return getInitialSnapshot(snapshot, timestamp);
        }

        long startTimeMillis = System.currentTimeMillis();
        return ypRepository.watchObjects(snapshot.getTimestamp(), Optional.empty())
                .orTimeout(settings.getWatchesTimeout().toNanos(), TimeUnit.NANOSECONDS)
                .whenComplete((x, error) -> {
                    metricWatchObjectsTimeMilliseconds = System.currentTimeMillis() - startTimeMillis;
                    if (error != null) {
                        metricWatchObjectsErrorsCount.incrementAndGet();
                    }
                })
                .thenCompose(watchedObjects -> {
                    metricUpdatedObjectsCount.addAndGet(watchedObjects.getEvents().size());
                    Map<String, Try<YpObject<Meta, Spec, Status>>> oldObjects = snapshot.getObjects();
                    if (watchedObjects.getEvents().size() == 0) {
                        // Nothing changed: keep the objects and only advance the snapshot's timestamp.
                        var snapshotWithUpdatedTimestamp = new SelectedObjects<>(oldObjects, watchedObjects.getTimestamp());
                        tryUpdateCurrentSnapshot(snapshot, snapshotWithUpdatedTimestamp);
                        return CompletableFuture.completedFuture(snapshotWithUpdatedTimestamp.getObjects());
                    }

                    final Set<String> updatedObjectIds = watchedObjects.getEvents().keySet();

                    // Take all objects from the previous snapshot except the updated ones, which are re-fetched below.
                    // The HashMap factory is explicit because this map is mutated with putAll() later, and the
                    // two-argument Collectors.toMap() gives no guarantee about the mutability of its result.
                    Map<String, Try<YpObject<Meta, Spec, Status>>> mapForResultsAccumulation = oldObjects
                            .entrySet()
                            .stream()
                            .filter(e -> !updatedObjectIds.contains(e.getKey()))
                            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> b, HashMap::new));

                    LOG.debug("[{}] Found {} updated objects in {} ms (yp timestamp {})",
                            ypRepository, updatedObjectIds.size(), metricWatchObjectsTimeMilliseconds, watchedObjects.getTimestamp());
                    long getObjectsStartTimeMillis = System.currentTimeMillis();
                    return loadObjects(updatedObjectIds, watchedObjects.getTimestamp())
                            .whenComplete((x, error) -> {
                                metricGetObjectsTimeMilliseconds = System.currentTimeMillis() - getObjectsStartTimeMillis;
                                if (error != null) {
                                    metricGetObjectsErrorsCount.incrementAndGet();
                                }
                            })
                            .thenApply(updatedObjects -> {
                                LOG.debug("[{}] Reloaded {} of {} objects in {} ms",
                                        ypRepository, updatedObjects.size(), updatedObjectIds.size(), metricGetObjectsTimeMilliseconds);
                                mapForResultsAccumulation.putAll(updatedObjects);
                                var newSnapshot = new SelectedObjects<>(mapForResultsAccumulation, watchedObjects.getTimestamp());
                                tryUpdateCurrentSnapshot(snapshot, newSnapshot);
                                return mapForResultsAccumulation;
                            });
                })
                // Tell success and failure apart explicitly, instead of using a null result as the error signal.
                // On failure, drop the snapshot and fall back to a full select at the latest state.
                .handle((objects, error) -> {
                    if (error == null) {
                        return CompletableFuture.completedFuture(objects);
                    }
                    LOG.warn("[{}] Resetting cache after error: {}", ypRepository, error);
                    tryUpdateCurrentSnapshot(snapshot, null);
                    return getInitialSnapshot(null, Optional.empty());
                })
                .thenCompose(future -> future);
    }

    /**
     * Atomically replaces the current snapshot, provided it is still {@code lastSeenSnapshot}.
     *
     * <p>Also updates the object-count gauge when the new snapshot carries a non-null objects map.
     *
     * @param lastSeenSnapshot snapshot the caller based its update on (may be null)
     * @param newValue         snapshot to install (may be null to reset the cache)
     * @throws RuntimeException if the current snapshot is no longer {@code lastSeenSnapshot}
     */
    private void tryUpdateCurrentSnapshot(SelectedObjects<Meta, Spec, Status> lastSeenSnapshot,
                                          SelectedObjects<Meta, Spec, Status> newValue) {
        if (!currentSnapshot.compareAndSet(lastSeenSnapshot, newValue)) {
            throw new RuntimeException(String.format("Failed to update objects cache for %s. Objects was already updated from another thread.", ypRepository));
        }

        if (newValue != null && newValue.getObjects() != null) {
            metricYpObjectsCount = newValue.getObjects().size();
        }
    }

    /**
     * Fetches the given objects at {@code timestamp}, one batch after another.
     *
     * <p>Ids for which YP returns nothing are absent from the result. The current-batch-offset gauge is reset to
     * 0 when the chain finishes, whether it succeeded or failed.
     *
     * @param updatedObjectIds ids to fetch
     * @param timestamp        YP timestamp to read at
     * @return future with the fetched objects
     */
    private CompletableFuture<Map<String, Try<YpObject<Meta, Spec, Status>>>> loadObjects(Set<String> updatedObjectIds, long timestamp) {
        Map<String, Try<YpObject<Meta, Spec, Status>>> mapForResultsAccumulation = new HashMap<>();
        CompletableFuture<?> chainOfYpGetObjectCalls = CompletableFuture.completedFuture(null);
        int offset = 0;
        for (List<String> batch : Iterables.partition(updatedObjectIds, settings.getGetObjectsBatchSize())) {
            int currentOffset = offset;
            // Each batch starts only after the previous one completed, so the accumulator is never written concurrently.
            chainOfYpGetObjectCalls = chainOfYpGetObjectCalls.thenCompose(x -> loadNextBatchOfObjects(currentOffset, batch, mapForResultsAccumulation, timestamp));
            offset += batch.size();
        }
        return chainOfYpGetObjectCalls
                .thenApply(x -> mapForResultsAccumulation)
                .whenComplete((x, error) -> metricCurrentBatchOffset = 0);
    }

    /**
     * Fetches one batch of objects and adds the present ones to {@code mapForResultsAccumulation}.
     *
     * <p>NOTE(review): this assumes the repository returns one element per requested id, in the same order as
     * {@code ids}; confirm this against {@code getObjects}.
     *
     * @param offset                    position of this batch within all requested ids (exported as a gauge)
     * @param ids                       ids in this batch
     * @param mapForResultsAccumulation map the fetched objects are put into, keyed by id
     * @param timestamp                 YP timestamp to read at
     * @return future completing when the batch has been merged
     */
    private CompletableFuture<?> loadNextBatchOfObjects(int offset,
                                                        List<String> ids,
                                                        Map<String, Try<YpObject<Meta, Spec, Status>>> mapForResultsAccumulation,
                                                        long timestamp) {
        metricCurrentBatchOffset = offset;
        return ypRepository.getObjects(ids, selectorWithLabels, timestamp)
                .thenAccept(list -> {
                    for (int i = 0; i < list.size(); i++) {
                        String id = ids.get(i);
                        list.get(i).ifPresent(ypObject -> mapForResultsAccumulation.put(id, Try.success(ypObject)));
                    }
                });
    }
}
