package ru.yandex.qe.dispenser.domain.dao.person;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import javax.inject.Inject;

import com.google.common.base.Stopwatch;
import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.qe.dispenser.domain.Person;
import ru.yandex.qe.dispenser.domain.YaGroup;
import ru.yandex.qe.dispenser.domain.index.LongIndexBase;
import ru.yandex.qe.dispenser.solomon.SolomonHolder;
import ru.yandex.monlib.metrics.histogram.Histograms;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.monlib.metrics.primitives.Histogram;
import ru.yandex.monlib.metrics.primitives.Rate;
import ru.yandex.monlib.metrics.registry.MetricRegistry;

public class StaffCache {

    private static final Logger LOG = LoggerFactory.getLogger(StaffCache.class);
    // Names of the metrics registered by the constructor; all share the same prefix.
    public static final String SENSOR_PREFIX = "staff_cache_refresh_task.";
    public static final String ELAPSED_TIME_SENSOR = SENSOR_PREFIX + "elapsed_time";
    public static final String ERROR_RATE_SENSOR = SENSOR_PREFIX + "error_rate";
    public static final String LAST_START_SENSOR = SENSOR_PREFIX + "time_since_last_start";
    public static final String LAST_SUCCESS_SENSOR = SENSOR_PREFIX + "time_since_last_success_end";

    // Source of person records; also used for the bulk loads in refresh().
    @NotNull
    private final PersonDao personDao;
    // Source of person <-> group membership data.
    @NotNull
    private final PersonGroupMembershipDao personGroupMembershipDao;
    // Person caches, one per lookup key (id, uid, login). Entries are loaded on demand through the DAO.
    @NotNull
    private final LoadingCache<Long, Person> personsById;
    @NotNull
    private final LoadingCache<Long, Person> personsByUid;
    @NotNull
    private final LoadingCache<String, Person> personsByLogin;
    // Members of a group, keyed by group id. Person objects are resolved through personsById.
    @NotNull
    private final LoadingCache<Long, Set<Person>> personsByGroupId;
    // Groups of a person, keyed by person id.
    @NotNull
    private final LoadingCache<Long, Set<YaGroup>> groupsByPersonId;
    // Negative caches: a key present here was recently looked up and not found,
    // so lookups for it return "absent" without calling the loader again.
    @NotNull
    private final Cache<Long, Boolean> missingIds;
    @NotNull
    private final Cache<Long, Boolean> missingUids;
    @NotNull
    private final Cache<String, Boolean> missingLogins;
    // Duration of each scheduledRefresh() run, recorded in milliseconds.
    @NotNull
    private final Histogram elapsedTime;
    // Incremented once per failed scheduledRefresh() run.
    @NotNull
    private final Rate errorRate;
    // Time source for the "time since last ..." gauges.
    @NotNull
    private final Ticker ticker = Ticker.systemTicker();

    // Ticker readings converted to milliseconds; volatile because they are written by
    // scheduledRefresh() and read by the lazily evaluated gauges registered in the constructor.
    private volatile long lastStart;
    private volatile long lastSuccessEnd;

    /**
     * Creates the cache with empty person, membership and negative caches and registers
     * the refresh metrics in the root registry of the given holder.
     *
     * @param personDao                DAO used to load persons by id, uid and login
     * @param personGroupMembershipDao DAO used to load person/group memberships
     * @param solomonHolder            provider of the metric registry
     */
    @Inject
    public StaffCache(@NotNull final PersonDao personDao, @NotNull final PersonGroupMembershipDao personGroupMembershipDao,
                      @NotNull final SolomonHolder solomonHolder) {
        this.personDao = personDao;
        this.personGroupMembershipDao = personGroupMembershipDao;

        // Person caches: up to 100000 entries, kept for 24 hours after being written
        this.personsById = boundedBuilder(100000, 24, TimeUnit.HOURS).build(new PersonByIdLoader(personDao));
        this.personsByUid = boundedBuilder(100000, 24, TimeUnit.HOURS).build(new PersonByUidLoader(personDao));
        this.personsByLogin = boundedBuilder(100000, 24, TimeUnit.HOURS).build(new PersonByLoginLoader(personDao));

        // Membership caches: up to 100000 entries, kept for 30 minutes after being written
        this.personsByGroupId = boundedBuilder(100000, 30, TimeUnit.MINUTES)
                .build(new PersonsByGroupIdLoader(personGroupMembershipDao, personsById));
        this.groupsByPersonId = boundedBuilder(100000, 30, TimeUnit.MINUTES)
                .build(new GroupsByPersonIdLoader(personGroupMembershipDao));

        // Negative caches: up to 1000 entries, kept for one minute after being written
        this.missingIds = boundedBuilder(1000, 1, TimeUnit.MINUTES).build();
        this.missingUids = boundedBuilder(1000, 1, TimeUnit.MINUTES).build();
        this.missingLogins = boundedBuilder(1000, 1, TimeUnit.MINUTES).build();

        final MetricRegistry registry = solomonHolder.getRootRegistry();
        this.elapsedTime = registry.histogramRate(ELAPSED_TIME_SENSOR, Labels.of(), Histograms.exponential(22, 2, 1.0d));
        this.errorRate = registry.rate(ERROR_RATE_SENSOR, Labels.of());
        // Both gauges start counting from construction time until the first refresh updates them
        this.lastStart = TimeUnit.NANOSECONDS.toMillis(ticker.read());
        registry.lazyGaugeInt64(LAST_START_SENSOR, Labels.of(),
                () -> TimeUnit.NANOSECONDS.toMillis(ticker.read()) - lastStart);
        this.lastSuccessEnd = TimeUnit.NANOSECONDS.toMillis(ticker.read());
        registry.lazyGaugeInt64(LAST_SUCCESS_SENSOR, Labels.of(),
                () -> TimeUnit.NANOSECONDS.toMillis(ticker.read()) - lastSuccessEnd);
    }

    /**
     * Returns a cache builder limited to {@code maxEntries} entries whose entries expire
     * {@code ttl} {@code ttlUnit}s after being written.
     */
    private static CacheBuilder<Object, Object> boundedBuilder(final long maxEntries, final long ttl, final TimeUnit ttlUnit) {
        return CacheBuilder.newBuilder()
                .maximumSize(maxEntries)
                .expireAfterWrite(ttl, ttlUnit);
    }

    /**
     * Runs {@link #refresh()} and reports the outcome: logs start and finish, updates the
     * last-start and last-success timestamps, records the elapsed time and, on failure,
     * increments the error rate. Any failure is logged and rethrown to the caller.
     */
    public void scheduledRefresh() {
        LOG.info("Refreshing staff cache...");
        final Stopwatch timer = Stopwatch.createStarted();
        lastStart = TimeUnit.NANOSECONDS.toMillis(ticker.read());
        try {
            refresh();
        } catch (Throwable failure) {
            LOG.error("Failed to refresh staff cache", failure);
            final long failedAfterMillis = timer.stop().elapsed(TimeUnit.MILLISECONDS);
            errorRate.inc();
            LOG.info("Staff cache refresh finished with error in {} seconds", TimeUnit.MILLISECONDS.toSeconds(failedAfterMillis));
            elapsedTime.record(failedAfterMillis);
            throw failure;
        }
        final long succeededAfterMillis = timer.stop().elapsed(TimeUnit.MILLISECONDS);
        lastSuccessEnd = TimeUnit.NANOSECONDS.toMillis(ticker.read());
        LOG.info("Staff cache refresh finished successfully in {} seconds", TimeUnit.MILLISECONDS.toSeconds(succeededAfterMillis));
        elapsedTime.record(succeededAfterMillis);
    }

    /**
     * Reloads the caches in bulk from the DAOs.
     * <p>
     * All active persons are loaded once and the same person objects are put into the id, uid
     * and login caches to reduce memory consumption. Keys of loaded persons are evicted from the
     * matching negative caches, and entries that are cached as active but are no longer in the
     * active set are invalidated. After that the members of every group mentioned in project
     * roles and the groups of every active person are loaded and cached.
     * <p>
     * Every stage logs its own duration.
     *
     * @throws IllegalStateException if two active persons share the same id, uid or login
     *                               (thrown by {@link Collectors#toMap})
     */
    public void refresh() {
        Stopwatch stopwatch = Stopwatch.createStarted();
        final Set<Person> allActive = personDao.getAllActive();
        LOG.info("Active persons loaded in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));

        refreshPersonIndex("id", allActive, LongIndexBase::getId, personsById, missingIds);
        refreshPersonIndex("uid", allActive, Person::getUid, personsByUid, missingUids);
        refreshPersonIndex("login", allActive, Person::getLogin, personsByLogin, missingLogins);

        // Load persons for every group mentioned in project roles, reuse already cached person objects
        stopwatch = Stopwatch.createStarted();
        final Map<Long, Set<Long>> personIdsByGroupIdsForProjectMembership = personGroupMembershipDao.findPersonIdsByGroupIdsForProjectMembership();
        LOG.info("Persons by groups query finished in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));
        stopwatch = Stopwatch.createStarted();
        final Map<Long, Set<Person>> projectPersonsByGroups = personIdsByGroupIdsForProjectMembership.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> loadCachedPersons(e.getValue())));
        LOG.info("Persons by groups loaded in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));
        stopwatch = Stopwatch.createStarted();
        personsByGroupId.putAll(projectPersonsByGroups);
        LOG.info("Persons by groups cache updated in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));

        // Load all groups for every non-dismissed non-deleted person
        stopwatch = Stopwatch.createStarted();
        final Map<Long, Set<YaGroup>> groupsByActivePerson = personDao.getGroupsForAllActiveUsers();
        LOG.info("Groups by person loaded in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));
        stopwatch = Stopwatch.createStarted();
        groupsByPersonId.putAll(groupsByActivePerson);
        LOG.info("Groups by person cache updated in {} ms", stopwatch.elapsed(TimeUnit.MILLISECONDS));
    }

    /**
     * Refreshes one person cache (by id, uid or login) and its negative cache from the set of active persons.
     *
     * @param keyName      key label used in log messages ("id", "uid" or "login")
     * @param allActive    all currently active persons
     * @param keyMapper    extracts the cache key from a person
     * @param valueCache   person cache to refresh
     * @param missingCache negative cache holding keys known to be absent
     */
    private static <K> void refreshPersonIndex(@NotNull final String keyName, @NotNull final Set<Person> allActive,
                                               @NotNull final Function<Person, K> keyMapper,
                                               @NotNull final LoadingCache<K, Person> valueCache,
                                               @NotNull final Cache<K, Boolean> missingCache) {
        Stopwatch stopwatch = Stopwatch.createStarted();
        final Map<K, Person> activeByKey = allActive.stream().collect(Collectors.toMap(keyMapper, p -> p));
        valueCache.putAll(activeByKey);
        LOG.info("Persons by {} cache updated in {} ms", keyName, stopwatch.elapsed(TimeUnit.MILLISECONDS));

        // Keys of known persons must no longer be reported as missing
        stopwatch = Stopwatch.createStarted();
        missingCache.invalidateAll(activeByKey.keySet());
        LOG.info("Missing person {} cache updated in {} ms", keyName, stopwatch.elapsed(TimeUnit.MILLISECONDS));

        // Invalidate already inactive as they were not refreshed otherwise:
        // the entry is cached as active but the person is absent from the active set
        stopwatch = Stopwatch.createStarted();
        final Set<K> keysToInvalidate = valueCache.asMap().entrySet().stream()
                .filter(e -> !e.getValue().isDismissed() && !e.getValue().isDeleted() && !activeByKey.containsKey(e.getKey()))
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
        valueCache.invalidateAll(keysToInvalidate);
        LOG.info("Persons by {} cache inactive invalidation finished in {} ms", keyName, stopwatch.elapsed(TimeUnit.MILLISECONDS));
    }

    /**
     * Resolves person ids through the by-id cache, so already cached person objects are reused.
     *
     * @throws UncheckedExecutionException if loading fails with a checked exception
     */
    private Set<Person> loadCachedPersons(@NotNull final Set<Long> personIds) {
        try {
            return new HashSet<>(personsById.getAll(personIds).values());
        } catch (ExecutionException ex) {
            throw new UncheckedExecutionException(ex);
        }
    }

    /**
     * Drops every entry from all person, membership and negative caches.
     */
    public void clear() {
        personsById.invalidateAll();
        personsByUid.invalidateAll();
        personsByLogin.invalidateAll();
        personsByGroupId.invalidateAll();
        groupsByPersonId.invalidateAll();
        missingIds.invalidateAll();
        missingUids.invalidateAll();
        missingLogins.invalidateAll();
    }

    /**
     * Returns the union of the members of the given groups.
     *
     * @param groups groups to look up; an empty collection yields an empty set
     * @return all persons that belong to at least one of the groups
     * @throws UncheckedExecutionException if loading group members fails with a checked exception
     */
    @NotNull
    public Set<Person> getPersonsInGroups(@NotNull final Collection<YaGroup> groups) {
        if (groups.isEmpty()) {
            return Collections.emptySet();
        }
        final Set<Long> groupIds = groups.stream().map(LongIndexBase::getId).collect(Collectors.toSet());
        final Map<Long, Set<Person>> membersByGroupId;
        try {
            membersByGroupId = personsByGroupId.getAll(groupIds);
        } catch (ExecutionException e) {
            throw new UncheckedExecutionException(e);
        }
        final Set<Person> members = new HashSet<>();
        for (final Set<Person> groupMembers : membersByGroupId.values()) {
            members.addAll(groupMembers);
        }
        return members;
    }

    /**
     * Returns the groups of the given person, loading them through the cache when not yet cached.
     *
     * @param person person whose groups are requested
     * @return groups cached or loaded for the person's id
     */
    @NotNull
    public Set<YaGroup> getPersonGroups(@NotNull final Person person) {
        final Long personId = person.getId();
        return groupsByPersonId.getUnchecked(personId);
    }

    /**
     * Looks up a person by id.
     *
     * @param id person id
     * @return the person, or an empty optional when no person with this id was found
     */
    @NotNull
    public Optional<Person> tryGetPersonById(final long id) {
        final Long key = id;
        return getOptional(key, missingIds, personsById);
    }

    /**
     * Looks up a person by login.
     *
     * @param login person login
     * @return the person, or an empty optional when no person with this login was found
     */
    @NotNull
    public Optional<Person> tryGetPersonByLogin(@NotNull final String login) {
        final Optional<Person> person = getOptional(login, missingLogins, personsByLogin);
        return person;
    }

    /**
     * Looks up a person by uid.
     *
     * @param uid person uid
     * @return the person, or an empty optional when no person with this uid was found
     */
    @NotNull
    public Optional<Person> tryGetPersonByUid(final long uid) {
        final Long key = uid;
        return getOptional(key, missingUids, personsByUid);
    }

    /**
     * Looks up several persons by id; ids that cannot be resolved are skipped.
     *
     * @param ids person ids
     * @return the persons found for the given ids
     */
    @NotNull
    public Set<Person> tryGetPersonsByIds(@NotNull final Collection<Long> ids) {
        final Set<Person> persons = getMany(ids, missingIds, personsById, LongIndexBase::getId);
        return persons;
    }

    /**
     * Looks up several persons by login; logins that cannot be resolved are skipped.
     *
     * @param logins person logins
     * @return the persons found for the given logins
     */
    @NotNull
    public Set<Person> tryGetPersonsByLogins(@NotNull final Collection<String> logins) {
        final Set<Person> persons = getMany(logins, missingLogins, personsByLogin, Person::getLogin);
        return persons;
    }

    /**
     * Looks up several persons by uid; uids that cannot be resolved are skipped.
     *
     * @param uids person uids
     * @return the persons found for the given uids
     */
    @NotNull
    public Set<Person> tryGetPersonsByUids(@NotNull final Collection<Long> uids) {
        final Set<Person> persons = getMany(uids, missingUids, personsByUid, Person::getUid);
        return persons;
    }

    /**
     * Reads a single value through the loading cache, consulting the negative cache first.
     * <p>
     * A key present in {@code missingCache} is answered with an empty optional without touching
     * the loader. When the loader reports {@link NotFoundException}, the key is remembered in
     * {@code missingCache} and an empty optional is returned; any other load failure is rethrown.
     *
     * @param key          key to look up
     * @param missingCache negative cache for this kind of key
     * @param valueCache   loading cache holding the values
     * @return the value, or an empty optional when the key is known or found to be missing
     */
    private <K, V> Optional<V> getOptional(@NotNull final K key, @NotNull final Cache<K, Boolean> missingCache, @NotNull final LoadingCache<K, V> valueCache) {
        final boolean knownMissing = missingCache.getIfPresent(key) != null;
        if (knownMissing) {
            return Optional.empty();
        }
        final V value;
        try {
            value = valueCache.getUnchecked(key);
        } catch (UncheckedExecutionException e) {
            if (!(e.getCause() instanceof NotFoundException)) {
                throw e;
            }
            missingCache.put(key, Boolean.TRUE);
            return Optional.empty();
        }
        return Optional.of(value);
    }

    /**
     * Reads several values through the loading cache, skipping keys known to be missing.
     * <p>
     * When the bulk load does not return a value for every requested key, the values that are
     * present in the cache are returned and the remaining keys are remembered in {@code missingCache}.
     *
     * @param keys         keys to look up; an empty collection yields an empty set
     * @param missingCache negative cache for this kind of key
     * @param valueCache   loading cache holding the values
     * @param keyMapper    extracts the key from a value, used to tell which keys were found
     * @return the values found for the given keys
     * @throws UncheckedExecutionException if loading fails with a checked exception
     */
    private <K, V> Set<V> getMany(@NotNull final Collection<K> keys, @NotNull final Cache<K, Boolean> missingCache,
                                  @NotNull final LoadingCache<K, V> valueCache, @NotNull final Function<V, K> keyMapper) {
        if (keys.isEmpty()) {
            return Collections.emptySet();
        }
        final Set<K> candidateKeys = new HashSet<>();
        for (final K key : keys) {
            if (missingCache.getIfPresent(key) == null) {
                candidateKeys.add(key);
            }
        }
        if (candidateKeys.isEmpty()) {
            return Collections.emptySet();
        }
        try {
            return new HashSet<>(valueCache.getAll(candidateKeys).values());
        } catch (CacheLoader.InvalidCacheLoadException e) {
            // The loader could not resolve some keys; collect what is cached now
            final Set<V> found = new HashSet<>();
            final Set<K> foundKeys = new HashSet<>();
            for (final K key : candidateKeys) {
                final V value = valueCache.getIfPresent(key);
                if (value != null) {
                    found.add(value);
                }
            }
            for (final V value : found) {
                foundKeys.add(keyMapper.apply(value));
            }
            // Everything requested but not found is remembered as missing
            for (final K key : candidateKeys) {
                if (!foundKeys.contains(key)) {
                    missingCache.put(key, Boolean.TRUE);
                }
            }
            return found;
        } catch (ExecutionException e) {
            throw new UncheckedExecutionException(e);
        }
    }

    /**
     * Base cache loader that maps "not found" lookups onto {@link NotFoundException} for single
     * loads and onto absent map entries for bulk loads.
     *
     * @param <K> key type
     * @param <V> value type
     */
    private static abstract class ValueLoader<K, V> extends CacheLoader<K, V> {

        private ValueLoader() {
        }

        /**
         * Loads a single value.
         *
         * @throws NotFoundException if there is no value for the key
         */
        @Override
        public V load(final K key) {
            final Optional<V> loaded = loadOptional(key);
            if (!loaded.isPresent()) {
                throw new NotFoundException();
            }
            return loaded.get();
        }

        /**
         * Loads the values for the given keys in one call; keys without a value are left out of the result.
         */
        @Override
        public Map<K, V> loadAll(final Iterable<? extends K> keys) {
            final Set<K> requestedKeys = StreamSupport.stream(keys.spliterator(), false).collect(Collectors.toSet());
            return loadMany(requestedKeys).stream().collect(Collectors.toMap(this::getKey, Function.identity()));
        }

        /** Loads the value for a single key, if there is one. */
        protected abstract Optional<V> loadOptional(final K key);

        /** Loads the values that exist for the given keys. */
        protected abstract Set<V> loadMany(final Collection<K> keys);

        /** Returns the cache key of the given value. */
        protected abstract K getKey(final V value);
    }

    /**
     * Loads persons by id through {@link PersonDao}.
     */
    private static class PersonByIdLoader extends ValueLoader<Long, Person> {

        @NotNull
        private final PersonDao dao;

        private PersonByIdLoader(@NotNull final PersonDao personDao) {
            this.dao = personDao;
        }

        @Override
        protected Optional<Person> loadOptional(final Long id) {
            return dao.tryReadPersonById(id);
        }

        @Override
        protected Set<Person> loadMany(final Collection<Long> ids) {
            return dao.tryReadPersonsByIds(ids);
        }

        @Override
        protected Long getKey(final Person person) {
            return person.getId();
        }

    }

    /**
     * Loads persons by uid through {@link PersonDao}.
     */
    private static class PersonByUidLoader extends ValueLoader<Long, Person> {

        @NotNull
        private final PersonDao dao;

        private PersonByUidLoader(@NotNull final PersonDao personDao) {
            this.dao = personDao;
        }

        @Override
        protected Optional<Person> loadOptional(final Long uid) {
            return dao.tryReadPersonByUid(uid);
        }

        @Override
        protected Set<Person> loadMany(final Collection<Long> uids) {
            return dao.tryReadPersonsByUids(uids);
        }

        @Override
        protected Long getKey(final Person person) {
            return person.getUid();
        }

    }

    /**
     * Loads persons by login through {@link PersonDao}.
     */
    private static class PersonByLoginLoader extends ValueLoader<String, Person> {

        @NotNull
        private final PersonDao dao;

        private PersonByLoginLoader(@NotNull final PersonDao personDao) {
            this.dao = personDao;
        }

        @Override
        protected Optional<Person> loadOptional(final String login) {
            return dao.tryReadPersonByLogin(login);
        }

        @Override
        protected Set<Person> loadMany(final Collection<String> logins) {
            return dao.tryReadPersonsByLogins(logins);
        }

        @Override
        protected String getKey(final Person person) {
            return person.getLogin();
        }

    }

    /**
     * Loads the members of a group: member ids come from {@link PersonGroupMembershipDao},
     * person objects are resolved through the by-id person cache so cached instances are reused.
     */
    private static class PersonsByGroupIdLoader extends CacheLoader<Long, Set<Person>> {

        @NotNull
        private final PersonGroupMembershipDao personGroupMembershipDao;
        @NotNull
        private final LoadingCache<Long, Person> personsById;

        private PersonsByGroupIdLoader(@NotNull final PersonGroupMembershipDao personGroupMembershipDao, @NotNull final LoadingCache<Long, Person> personsById) {
            this.personGroupMembershipDao = personGroupMembershipDao;
            this.personsById = personsById;
        }

        /**
         * Loads the members of a single group.
         *
         * @throws UncheckedExecutionException if resolving persons fails with a checked exception
         */
        @Override
        public Set<Person> load(final Long key) {
            return loadPersons(personGroupMembershipDao.findPersonIdsByGroupId(key));
        }

        /**
         * Loads the members of several groups with a single membership query.
         * <p>
         * The result contains an entry for every requested group. A bulk cache read fails with
         * {@code InvalidCacheLoadException} when the loader leaves out a requested key, so a group
         * for which the DAO reports no members is mapped to an empty set, matching what
         * {@link #load(Long)} returns for an empty id collection.
         *
         * @throws UncheckedExecutionException if resolving persons fails with a checked exception
         */
        @Override
        public Map<Long, Set<Person>> loadAll(final Iterable<? extends Long> keys) {
            final Set<Long> groupIds = StreamSupport.stream(keys.spliterator(), false).collect(Collectors.toSet());
            final Map<Long, Set<Long>> personIdsByGroupId = personGroupMembershipDao.findPersonIdsByGroupIds(groupIds);
            return groupIds.stream().collect(Collectors.toMap(groupId -> groupId,
                    groupId -> loadPersons(personIdsByGroupId.getOrDefault(groupId, Collections.emptySet()))));
        }

        /** Resolves person ids through the by-id person cache. */
        private Set<Person> loadPersons(final Iterable<? extends Long> personIds) {
            try {
                return new HashSet<>(personsById.getAll(personIds).values());
            } catch (ExecutionException e) {
                throw new UncheckedExecutionException(e);
            }
        }

    }

    /**
     * Loads the groups of a person through {@link PersonGroupMembershipDao}.
     */
    private static class GroupsByPersonIdLoader extends CacheLoader<Long, Set<YaGroup>> {

        @NotNull
        private final PersonGroupMembershipDao membershipDao;

        private GroupsByPersonIdLoader(@NotNull final PersonGroupMembershipDao personGroupMembershipDao) {
            this.membershipDao = personGroupMembershipDao;
        }

        /** Loads the groups of a single person. */
        @Override
        public Set<YaGroup> load(final Long personId) {
            return membershipDao.findGroupsByPersonId(personId);
        }

        /** Loads the groups of several persons with a single DAO call. */
        @Override
        public Map<Long, Set<YaGroup>> loadAll(final Iterable<? extends Long> keys) {
            final Set<Long> personIds = StreamSupport.stream(keys.spliterator(), false).collect(Collectors.toSet());
            return membershipDao.findGroupsByPersonIds(personIds);
        }

    }

    /**
     * Thrown by {@link ValueLoader#load} when there is no value for the requested key.
     * Lookup methods of the enclosing class translate it into an absent result.
     */
    private static class NotFoundException extends RuntimeException {

        private NotFoundException() {
            super();
        }
    }

}
