package ru.yandex.infra.stage.cache;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.collect.Maps;
import com.google.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.metrics.NamespacedGaugeRegistry;

public class StorageBasedCache<TValue, TProtoValue extends Message> implements Cache<TValue> {
    private static final Logger LOG = LoggerFactory.getLogger(StorageBasedCache.class);

    static final String METRIC_COUNT = "count";
    static final String METRIC_PUT = "put";
    static final String METRIC_REMOVE = "remove";
    static final String METRIC_GET_HIT = "get_hit";
    static final String METRIC_GET_MISS = "get_miss";

    private final static Set<CachedObjectType> initializedMetricsByType = ConcurrentHashMap.newKeySet();
    private final AtomicLong metricPutCount = new AtomicLong();
    private final AtomicLong metricRemoveCount = new AtomicLong();
    private final AtomicLong metricGetHitCount = new AtomicLong();
    private final AtomicLong metricGetMissCount = new AtomicLong();

    private final CachedObjectType<TValue, TProtoValue> type;
    private final Map<String, TValue> values;
    private final CacheStorage<TProtoValue> storage;

    public StorageBasedCache(CachedObjectType<TValue, TProtoValue> type, CacheStorageFactory storageFactory, GaugeRegistry gaugeRegistry) {
        this.type = type;
        this.storage = storageFactory.createStorage(type);

        LOG.debug("Loading all {} cache records", type.getName());
        try {
            long startTimeMillis = System.currentTimeMillis();
            Map<String, TValue> valuesFromStorage = storage.init()
                    .thenCompose(x -> storage.read())
                    .thenApply(map -> Maps.transformValues(map, type.getFromProto()::apply))
                    .get();

            long ms = System.currentTimeMillis() - startTimeMillis;
            LOG.info("Loaded {} {} cache records for {} ms ({} nodes/s)", valuesFromStorage.size(), type.getName(),
                    ms,
                    1000L * valuesFromStorage.size() / (ms+1));

            values = new ConcurrentHashMap<>(valuesFromStorage);
        } catch (InterruptedException|ExecutionException e) {
            throw new RuntimeException(String.format("Failed to load %s cache values", type.getName()), e);
        }

        //initializing only first instance for type.
        //All other replicas will be temporarily used for cache.export
        if (initializedMetricsByType.add(type)) {
            GaugeRegistry registry = new NamespacedGaugeRegistry(gaugeRegistry, "cache." + type.getName());
            registry.add(METRIC_COUNT, new GolovanableGauge<>(values::size, "axxx"));
            registry.add(METRIC_PUT, new GolovanableGauge<>(metricPutCount::get, "dmmm"));
            registry.add(METRIC_REMOVE, new GolovanableGauge<>(metricRemoveCount::get, "dmmm"));
            registry.add(METRIC_GET_HIT, new GolovanableGauge<>(metricGetHitCount::get, "dmmm"));
            registry.add(METRIC_GET_MISS, new GolovanableGauge<>(metricGetMissCount::get, "dmmm"));
        }
    }

    @Override
    public Map<String, TValue> getAll() {
        return values;
    }

    @Override
    public Optional<TValue> get(String key) {
        final TValue result = values.get(key);
        if (result != null) {
            metricGetHitCount.incrementAndGet();
            return Optional.of(result);
        }

        metricGetMissCount.incrementAndGet();
        return Optional.empty();
    }

    @Override
    public CompletableFuture<?> put(String key, TValue value) {
        metricPutCount.incrementAndGet();
        TValue result = values.put(key, value);
        final boolean updated = !value.equals(result);
        LOG.info("[{}] put into cache (added={}): {}", type.getName(), updated, key);
        if (updated) {
            return storage.write(key, type.getToProto().apply(value));
        }
        return CompletableFuture.completedFuture(null);
    }

    @Override
    public CompletableFuture<?> remove(String key) {
        metricRemoveCount.incrementAndGet();
        TValue result = values.remove(key);
        LOG.info("[{}] remove from cache (removed={}): {}", type.getName(), result != null, key);
        return storage.remove(key);
    }
}
