package ru.yandex.intranet.d.loaders;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.yandex.ydb.table.transaction.TransactionMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.cache.CacheMono;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

import ru.yandex.intranet.d.datasource.model.YdbSession;
import ru.yandex.intranet.d.datasource.model.YdbTableClient;
import ru.yandex.intranet.d.datasource.model.YdbTxSession;
import ru.yandex.intranet.d.model.TenantId;

/**
 * Caching loader of values addressed by an (id, tenant id) pair.
 *
 * <p>Two caches are maintained: one for values that were found and one for keys that were looked up
 * and turned out to be missing. Both are consulted before the underlying {@code findById} /
 * {@code findByIds} functions are invoked.
 *
 * @param <K> value id type
 * @param <V> cached value type
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class ByIdLoader<K, V> {

    private static final Logger LOG = LoggerFactory.getLogger(ByIdLoader.class);

    private final YdbTableClient ydbTableClient;
    private final String cacheValueDisplayName;
    private final int refreshPageSize;
    /** Values that were found, by key. */
    private final Cache<CacheKey<K>, V> presentById;
    /** Keys that were looked up and not found (the cached value is the id itself). */
    private final Cache<CacheKey<K>, K> missingById;
    private final BiFunction<YdbTxSession, List<Tuple2<K, TenantId>>, Mono<List<V>>> findByIds;
    private final BiFunction<YdbTxSession, Tuple2<K, TenantId>, Mono<Optional<V>>> findById;
    private final Function<V, K> idFromValue;
    private final Function<V, TenantId> tenantIdFromValue;

    /**
     * Creates a loader.
     *
     * @param maxSize                 maximum number of entries in the cache of found values
     * @param maxMissingSize          maximum number of entries in the cache of missing keys
     * @param expireAfterWrite        expiration of found values, counted from the last write
     * @param expireAfterWriteMissing expiration of missing keys, counted from the last write
     * @param ydbTableClient          client used to obtain sessions for the session-less methods
     * @param cacheValueDisplayName   name of the cached entity, used in log messages
     * @param refreshPageSize         number of keys reloaded per {@code findByIds} call in {@link #refresh()}
     * @param findByIds               bulk lookup function
     * @param findById                single value lookup function
     * @param idFromValue             extracts the id from a value
     * @param tenantIdFromValue       extracts the tenant id from a value
     */
    @SuppressWarnings("ParameterNumber")
    public ByIdLoader(long maxSize,
                      long maxMissingSize,
                      Duration expireAfterWrite,
                      Duration expireAfterWriteMissing,
                      YdbTableClient ydbTableClient,
                      String cacheValueDisplayName,
                      int refreshPageSize,
                      BiFunction<YdbTxSession, List<Tuple2<K, TenantId>>, Mono<List<V>>> findByIds,
                      BiFunction<YdbTxSession, Tuple2<K, TenantId>, Mono<Optional<V>>> findById,
                      Function<V, K> idFromValue,
                      Function<V, TenantId> tenantIdFromValue) {
        this.ydbTableClient = ydbTableClient;
        this.cacheValueDisplayName = cacheValueDisplayName;
        this.refreshPageSize = refreshPageSize;
        this.findByIds = findByIds;
        this.findById = findById;
        this.idFromValue = idFromValue;
        this.tenantIdFromValue = tenantIdFromValue;
        this.presentById = CacheBuilder.newBuilder()
                .maximumSize(maxSize)
                .expireAfterWrite(expireAfterWrite)
                .build();
        this.missingById = CacheBuilder.newBuilder()
                .maximumSize(maxMissingSize)
                .expireAfterWrite(expireAfterWriteMissing)
                .build();
    }

    /**
     * Returns the value for the key, from the cache if possible, otherwise via {@code findById}
     * executed in the supplied session. The lookup result (found or missing) is cached.
     *
     * @param session  session passed to {@code findById} on a cache miss
     * @param id       value id
     * @param tenantId tenant id
     * @return the value, or an empty optional if it is missing
     */
    public Mono<Optional<V>> getById(YdbTxSession session, K id, TenantId tenantId) {
        return CacheMono.lookup(this::getFromCacheById, new CacheKey<>(id, tenantId))
                .onCacheMissResume(() -> findById.apply(session, Tuples.of(id, tenantId)))
                .andWriteWith(this::putByIdToCache);
    }

    /**
     * Returns the values found for the keys. Keys already known to the caches are answered from them,
     * the rest are loaded via {@code findByIds} in the supplied session and cached. Keys without a
     * value are simply absent from the result; the result order is not tied to the order of {@code ids}.
     *
     * @param session session passed to {@code findByIds} on a cache miss
     * @param ids     (id, tenant id) pairs to look up
     * @return found values
     */
    public Mono<List<V>> getByIds(YdbTxSession session, List<Tuple2<K, TenantId>> ids) {
        if (ids.isEmpty()) {
            return Mono.just(List.of());
        }
        List<CacheKey<K>> keys = toCacheKeys(ids);
        return CacheMono.lookup(this::getFromCacheByIds, keys)
                .onCacheMissResume(() -> loadByIds(session, keys))
                .andWriteWith(this::putByIdsToCache);
    }

    /**
     * Same as {@link #getById(YdbTxSession, Object, TenantId)}, but on a cache miss the value is read
     * in a session obtained from the table client, using a stale read-only transaction.
     *
     * @param id       value id
     * @param tenantId tenant id
     * @return the value, or an empty optional if it is missing
     */
    public Mono<Optional<V>> getByIdImmediate(K id, TenantId tenantId) {
        return CacheMono.lookup(this::getFromCacheById, new CacheKey<>(id, tenantId))
                .onCacheMissResume(() -> loadByIdImmediate(id, tenantId))
                .andWriteWith(this::putByIdToCache);
    }

    /**
     * Same as {@link #getByIds(YdbTxSession, List)}, but on a cache miss the values are read
     * in a session obtained from the table client, using a stale read-only transaction.
     *
     * @param ids (id, tenant id) pairs to look up
     * @return found values
     */
    public Mono<List<V>> getByIdsImmediate(List<Tuple2<K, TenantId>> ids) {
        if (ids.isEmpty()) {
            return Mono.just(List.of());
        }
        List<CacheKey<K>> keys = toCacheKeys(ids);
        return CacheMono.lookup(this::getFromCacheByIds, keys)
                .onCacheMissResume(() -> loadByIdsImmediate(keys))
                .andWriteWith(this::putByIdsToCache);
    }

    /**
     * Replaces the cached value with the given one if its key is currently cached, and drops the key
     * from the cache of missing keys. A value that is not cached yet is not added.
     *
     * @param value new value
     */
    public void update(V value) {
        CacheKey<K> cacheKey = keyOf(value);
        if (presentById.getIfPresent(cacheKey) != null) {
            presentById.put(cacheKey, value);
        }
        // No presence check needed: all that matters is that the negative entry is gone afterwards.
        missingById.invalidate(cacheKey);
    }

    /**
     * Replaces the cached value with {@code merger(cached, value)} if its key is currently cached, and
     * drops the key from the cache of missing keys. A value that is not cached yet is not added.
     *
     * @param value  new value
     * @param merger combines the currently cached value (first argument) with the new one (second argument)
     */
    public void update(V value, BiFunction<V, V, V> merger) {
        CacheKey<K> cacheKey = keyOf(value);
        V oldValue = presentById.getIfPresent(cacheKey);
        if (oldValue != null) {
            presentById.put(cacheKey, merger.apply(oldValue, value));
        }
        // No presence check needed: all that matters is that the negative entry is gone afterwards.
        missingById.invalidate(cacheKey);
    }

    /**
     * Reloads a single key in a session obtained from the table client (stale read-only transaction).
     * Blocks the calling thread until done. Failures are logged and suppressed.
     *
     * @param key      value id
     * @param tenantId tenant id
     */
    public void refresh(K key, TenantId tenantId) {
        List<Tuple2<K, TenantId>> keyToLoad = List.of(Tuples.of(key, tenantId));
        ydbTableClient.usingSessionMonoRetryable(session ->
                refreshPage(session.asTxCommitRetryable(TransactionMode.STALE_READ_ONLY), keyToLoad)
        ).doOnError(this::logRefreshFailure)
                .onErrorResume(e -> Mono.empty()).block();
    }

    /**
     * Reloads a single key in a session obtained from the table client (stale read-only transaction).
     * Failures are logged and propagated to the subscriber.
     *
     * @param key      value id
     * @param tenantId tenant id
     * @return keys of the values that were found and cached
     */
    public Mono<List<CacheKey<K>>> asyncRefresh(K key, TenantId tenantId) {
        List<Tuple2<K, TenantId>> keyToLoad = List.of(Tuples.of(key, tenantId));
        return ydbTableClient.usingSessionMonoRetryable(session ->
                refreshPage(session.asTxCommitRetryable(TransactionMode.STALE_READ_ONLY), keyToLoad)
        ).doOnError(this::logRefreshFailure);
    }

    /**
     * Reloads a single key using the supplied session. Blocks the calling thread until done.
     * Failures are logged and suppressed.
     *
     * @param key      value id
     * @param tenantId tenant id
     * @param session  session passed to {@code findByIds}
     */
    public void refresh(K key, TenantId tenantId, YdbTxSession session) {
        List<Tuple2<K, TenantId>> keyToLoad = List.of(Tuples.of(key, tenantId));
        refreshPage(session, keyToLoad)
                .doOnError(this::logRefreshFailure)
                .onErrorResume(e -> Mono.empty()).block();
    }

    /**
     * Reloads a single key using the supplied session. Failures are logged and propagated
     * to the subscriber.
     *
     * @param key      value id
     * @param tenantId tenant id
     * @param session  session passed to {@code findByIds}
     * @return keys of the values that were found and cached
     */
    public Mono<List<CacheKey<K>>> asyncRefresh(K key, TenantId tenantId, YdbTxSession session) {
        List<Tuple2<K, TenantId>> keyToLoad = List.of(Tuples.of(key, tenantId));
        return refreshPage(session, keyToLoad)
                .doOnError(this::logRefreshFailure);
    }

    /**
     * Reloads every key currently held in either cache, in pages of {@code refreshPageSize} keys.
     * Blocks the calling thread until done. Failures are logged and suppressed.
     */
    public void refresh() {
        Set<CacheKey<K>> cachedKeysById = new HashSet<>(presentById.asMap().keySet());
        Set<CacheKey<K>> cachedMissingKeysById = new HashSet<>(missingById.asMap().keySet());
        if (cachedKeysById.isEmpty() && cachedMissingKeysById.isEmpty()) {
            return;
        }
        List<List<Tuple2<K, TenantId>>> valuesToLoadById = Lists.partition(Sets
                .union(cachedKeysById, cachedMissingKeysById)
                .stream().map(k -> Tuples.of(k.getId(), k.getTenantId()))
                .collect(Collectors.toList()), refreshPageSize);
        ydbTableClient.usingSessionMonoRetryable(session ->
                doRefreshById(session, valuesToLoadById)
        ).doOnError(this::logRefreshFailure)
                .onErrorResume(e -> Mono.empty()).block();
    }

    /** Reloads the pages one after another, each in its own stale read-only transaction. */
    private Mono<Void> doRefreshById(YdbSession session,
                                     List<List<Tuple2<K, TenantId>>> keysToLoad) {
        return Flux.fromIterable(keysToLoad)
                .concatMap(page ->
                        refreshPage(session.asTxCommitRetryable(TransactionMode.STALE_READ_ONLY), page))
                .then();
    }

    /**
     * Loads one page of keys, stores the found values in the cache of found values and removes their
     * keys from the cache of missing keys.
     *
     * @return keys of the values that were found
     */
    private Mono<List<CacheKey<K>>> refreshPage(YdbTxSession session,
                                                List<Tuple2<K, TenantId>> page) {
        return findByIds.apply(session, page)
                // doOnNext rather than doOnSuccess: the latter is also invoked, with null, on empty completion.
                .doOnNext(this::putToCacheById)
                .map(found -> found.stream().map(this::keyOf).collect(Collectors.toList()))
                // A key that has just been found must no longer be reported as missing.
                .doOnNext(foundKeys -> missingById.invalidateAll(foundKeys));
    }

    /** Logs a refresh failure together with the display name of the cached entity. */
    private void logRefreshFailure(Throwable e) {
        LOG.error("Failed to refresh {} cache", cacheValueDisplayName, e);
    }

    /** Builds the cache key of a value from its id and tenant id. */
    private CacheKey<K> keyOf(V value) {
        return new CacheKey<>(idFromValue.apply(value), tenantIdFromValue.apply(value));
    }

    /** Converts (id, tenant id) pairs to cache keys, preserving order. */
    private List<CacheKey<K>> toCacheKeys(List<Tuple2<K, TenantId>> ids) {
        return ids.stream().map(t -> new CacheKey<>(t.getT1(), t.getT2()))
                .collect(Collectors.toList());
    }

    /** Stores every value in the cache of found values. */
    private void putToCacheById(List<V> values) {
        values.forEach(value -> presentById.put(keyOf(value), value));
    }

    /**
     * Cache reader for single key lookups: emits the cached value, an empty optional for a key
     * cached as missing, or nothing at all when the key is unknown to both caches.
     */
    private Mono<Signal<? extends Optional<V>>> getFromCacheById(CacheKey<K> key) {
        V value = presentById.getIfPresent(key);
        if (value != null) {
            return Mono.just(Signal.next(Optional.of(value)));
        }
        K missingId = missingById.getIfPresent(key);
        if (missingId != null) {
            return Mono.just(Signal.next(Optional.empty()));
        }
        return Mono.empty();
    }

    /**
     * Cache writer for single key lookups: a present value goes to the cache of found values,
     * an empty optional registers the key as missing. Signals without a value are ignored.
     */
    private Mono<Void> putByIdToCache(CacheKey<K> key, Signal<? extends Optional<V>> value) {
        return Mono.fromRunnable(() -> {
            if (!value.hasValue()) {
                return;
            }
            Optional<V> valueToPut = value.get();
            if (valueToPut != null && valueToPut.isPresent()) {
                presentById.put(key, valueToPut.get());
            } else {
                missingById.put(key, key.getId());
            }
        });
    }

    /**
     * Cache reader for bulk lookups: emits the cached values only when every requested key is known
     * to one of the caches, otherwise emits nothing so that the miss handler runs.
     */
    private Mono<Signal<? extends List<V>>> getFromCacheByIds(List<CacheKey<K>> keys) {
        Map<CacheKey<K>, V> presentValues = presentById.getAllPresent(keys);
        Map<CacheKey<K>, K> presentMissingValues = missingById.getAllPresent(keys);
        if (keys.stream().allMatch(k -> presentValues.containsKey(k) || presentMissingValues.containsKey(k))) {
            return Mono.just(Signal.next(new ArrayList<>(presentValues.values())));
        }
        return Mono.empty();
    }

    /** Loads the keys unknown to the caches in the supplied session and appends the cached values. */
    private Mono<List<V>> loadByIds(YdbTxSession session, List<CacheKey<K>> keys) {
        Map<CacheKey<K>, V> cachedValues = presentById.getAllPresent(keys);
        List<Tuple2<K, TenantId>> idsToLoad = selectUncachedIds(keys, cachedValues);
        if (idsToLoad.isEmpty()) {
            // The caches were filled since the lookup; nothing to ask the storage for.
            List<V> cachedOnly = new ArrayList<>(cachedValues.values());
            return Mono.just(cachedOnly);
        }
        return findByIds.apply(session, idsToLoad)
                .map(loaded -> appendCached(loaded, cachedValues));
    }

    /** Returns the (id, tenant id) pairs of the keys that are in neither cache. */
    private List<Tuple2<K, TenantId>> selectUncachedIds(List<CacheKey<K>> keys,
                                                        Map<CacheKey<K>, V> cachedValues) {
        Map<CacheKey<K>, K> cachedMissing = missingById.getAllPresent(keys);
        return keys.stream()
                .filter(k -> !cachedValues.containsKey(k) && !cachedMissing.containsKey(k))
                .map(k -> Tuples.of(k.getId(), k.getTenantId()))
                .collect(Collectors.toList());
    }

    /** Returns the loaded values followed by the cached ones; the loaded list itself if nothing was cached. */
    private List<V> appendCached(List<V> loaded, Map<CacheKey<K>, V> cachedValues) {
        if (cachedValues.isEmpty()) {
            return loaded;
        }
        List<V> result = new ArrayList<>(loaded);
        result.addAll(cachedValues.values());
        return result;
    }

    /**
     * Cache writer for bulk lookups: requested keys that have a value in the result are stored as
     * found, the remaining requested keys are registered as missing. Keys already cached in the
     * respective cache are left untouched. Signals without a value are ignored.
     */
    private Mono<Void> putByIdsToCache(List<CacheKey<K>> keys,
                                       Signal<? extends List<V>> values) {
        return Mono.fromRunnable(() -> {
            if (!values.hasValue()) {
                return;
            }
            List<V> valuesToPut = values.get();
            if (valuesToPut == null) {
                return;
            }
            Map<CacheKey<K>, V> valuesByKey = valuesToPut.stream()
                    .collect(Collectors.toMap(this::keyOf, value -> value, (l, r) -> l));
            // Snapshots taken before writing, so entries that are already cached are not rewritten.
            Map<CacheKey<K>, V> alreadyPresent = presentById.getAllPresent(keys);
            Map<CacheKey<K>, K> alreadyMissing = missingById.getAllPresent(keys);
            for (CacheKey<K> key : keys) {
                V value = valuesByKey.get(key);
                if (value != null) {
                    if (!alreadyPresent.containsKey(key)) {
                        presentById.put(key, value);
                    }
                } else if (!alreadyMissing.containsKey(key)) {
                    missingById.put(key, key.getId());
                }
            }
        });
    }

    /** Loads a single value in a session obtained from the table client (stale read-only transaction). */
    private Mono<Optional<V>> loadByIdImmediate(K id, TenantId tenantId) {
        return ydbTableClient.usingSessionMonoRetryable(session ->
                findById.apply(session.asTxCommitRetryable(TransactionMode.STALE_READ_ONLY),
                        Tuples.of(id, tenantId)));
    }

    /**
     * Loads the keys unknown to the caches in a session obtained from the table client
     * (stale read-only transaction) and appends the cached values.
     */
    private Mono<List<V>> loadByIdsImmediate(List<CacheKey<K>> keys) {
        Map<CacheKey<K>, V> cachedValues = presentById.getAllPresent(keys);
        List<Tuple2<K, TenantId>> idsToLoad = selectUncachedIds(keys, cachedValues);
        if (idsToLoad.isEmpty()) {
            // The caches were filled since the lookup; no session is needed.
            List<V> cachedOnly = new ArrayList<>(cachedValues.values());
            return Mono.just(cachedOnly);
        }
        return ydbTableClient.usingSessionMonoRetryable(session ->
                findByIds.apply(session.asTxCommitRetryable(TransactionMode.STALE_READ_ONLY), idsToLoad)
                        .map(loaded -> appendCached(loaded, cachedValues))
        );
    }

}
