package ru.yandex.direct.dbutil.sharding;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import javax.annotation.Nonnull;
import javax.annotation.ParametersAreNonnullByDefault;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import one.util.streamex.EntryStream;
import org.jooq.Cursor;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.InsertValuesStep2;
import org.jooq.Record;
import org.jooq.Record2;
import org.jooq.Result;
import org.jooq.SelectJoinStep;
import org.jooq.Table;
import org.jooq.impl.DSL;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

import ru.yandex.direct.dbutil.QueryWithForbiddenShardMapping;
import ru.yandex.direct.dbutil.exception.AliveShardNotFoundException;
import ru.yandex.direct.dbutil.exception.NoAvailableShardsException;
import ru.yandex.direct.dbutil.wrapper.DatabaseWrapperProvider;
import ru.yandex.direct.dbutil.wrapper.ShardedDb;
import ru.yandex.direct.dbutil.wrapper.SimpleDb;
import ru.yandex.direct.solomon.SolomonUtils;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static ru.yandex.direct.dbutil.CollectionUtils.weightedShuffle;
import static ru.yandex.direct.dbutil.sharding.ValuesNormalizer.normalizeValue;
import static ru.yandex.direct.utils.FunctionalUtils.mapList;

@ParametersAreNonnullByDefault
@Component
public class ShardSupport {
    /**
     * Used as a placeholder for missing shard number information
     */
    public static final int NO_SHARD = -1;

    /** Maximum number of values sent in a single {@code IN (...)} clause by {@link #getValues}. */
    private static final int MAX_CHUNK_SIZE = 10000;
    /** Maximum number of entries kept in each per-{@link ShardKey} cache. */
    private static final int MAX_CACHE_SIZE = 100000;
    /** Maximum number of mapping tables {@link #lookupKeysWithValues} may join in one query. */
    private static final int MAX_REVERSE_CHAIN_SIZE = 5;
    /** Cached mappings are evicted this many seconds after they were written. */
    private static final long CACHE_EXPIRE_TIME_SECONDS = 10;

    private final DatabaseWrapperProvider databaseWrapperProvider;
    private final ShardedValuesGenerator valuesGenerator;
    private final List<Integer> availablePpcShards;
    /** Lazily created lookup caches, one per {@link ShardKey}; only mappings found in the database are cached. */
    private final ConcurrentHashMap<ShardKey, Cache<Object, Object>> cachesForKeys = new ConcurrentHashMap<>();

    /**
     * Creates the shard support service.
     *
     * @param databaseWrapperProvider provides database access; shard mapping tables are read from PPCDICT
     * @param valuesGenerator         generator used to produce new values for auto-increment keys
     * @param numOfPpcShards          number of PPC shards from the {@code db_shards} property;
     *                                shards are numbered {@code 1..numOfPpcShards}
     * @throws IllegalArgumentException if {@code numOfPpcShards} is negative
     */
    @Autowired
    public ShardSupport(
            DatabaseWrapperProvider databaseWrapperProvider,
            ShardedValuesGenerator valuesGenerator,
            @Value("${db_shards}") int numOfPpcShards) {
        if (numOfPpcShards < 0) {
            throw new IllegalArgumentException("numOfPpcShards: " + numOfPpcShards);
        }
        this.databaseWrapperProvider = databaseWrapperProvider;
        this.valuesGenerator = valuesGenerator;
        // The number of shards available from the config may be greater
        // than the number of shards available for use
        this.availablePpcShards = ImmutableList.copyOf(
                IntStream.rangeClosed(1, numOfPpcShards).boxed().iterator());
    }

    /**
     * Returns the list of available PPC database shard numbers.
     *
     * @return an immutable list {@code 1..numOfPpcShards}
     */
    public List<Integer> getAvailablePpcShards() {
        return availablePpcShards;
    }

    /**
     * Returns the lookup cache for {@code key}, creating it on first use.
     * Each cache holds at most {@link #MAX_CACHE_SIZE} entries, expires entries
     * {@link #CACHE_EXPIRE_TIME_SECONDS} seconds after write, and records stats.
     */
    private Cache<Object, Object> getCache(ShardKey key) {
        Cache<Object, Object> cache = cachesForKeys.get(key);
        if (cache == null) {
            // slow path, create the cache
            cache = cachesForKeys.computeIfAbsent(key, k -> CacheBuilder.newBuilder()
                    .maximumSize(MAX_CACHE_SIZE)
                    .recordStats()
                    .expireAfterWrite(CACHE_EXPIRE_TIME_SECONDS, TimeUnit.SECONDS)
                    .build()
            );

            // NOTE(review): this re-registers stats for all caches whenever a new cache is created, and can run
            // more than once if threads race here — confirm registerGuavaCachesStats tolerates repeated calls.
            SolomonUtils.registerGuavaCachesStats("shards_cache", cachesForKeys.values());
        }
        return cache;
    }

    /**
     * Queries one chunk of {@code keyField -> chainField} mappings, stores each found mapping in {@code cache}
     * and writes the mapped value into every position of {@code currentValues} that requested that key.
     * Declared as a separate generic method to work around Java generics limitations.
     *
     * @throws IllegalStateException if the database returns a key that was not requested
     */
    protected <R extends Record, T1, T2> void runGetValuesQuery(Table<R> table, Field<T1> keyField,
                                                                Field<T2> chainField, List<Object> chunk,
                                                                Cache<Object, Object> cache,
                                                                Map<Object, List<Integer>> indexesByValue,
                                                                ShardKey currentKey, Object[] currentValues) {
        Result<Record2<T1, T2>> dbResult = getDslContext()
                .select(keyField, chainField)
                .from(table)
                .where(keyField.in(chunk))
                .fetch();
        for (Record2<T1, T2> kv : dbResult) {
            // Normalize the same way the requested values were normalized, so the lookup below matches
            Object key = normalizeValue(kv.value1(), currentKey);
            List<Integer> indexes = indexesByValue.get(key);
            if (indexes == null) {
                throw new IllegalStateException(
                        "Key '" + currentKey.getName() + "' returned '" + kv.value1()
                                + "' from the database, but it was not requested");
            }
            cache.put(key, kv.value2());
            for (int index : indexes) {
                currentValues[index] = kv.value2();
            }
        }
    }

    /**
     * Maps every value of {@code key} step by step along the shard-key chain until {@code chainKey} is reached.
     * Each step consults the per-key cache first and queries the missing values in chunks of
     * {@link #MAX_CHUNK_SIZE}.
     *
     * @param key       source shard key (must not be the root key)
     * @param values    values of {@code key}; must not contain nulls; not modified
     * @param chainKey  target shard key, reachable from {@code key} via {@link ShardKey#getChainKey()}
     * @param chainType type each result is normalized to
     * @return a list aligned with {@code values}; an element is null if there is no mapping for that value
     * @throws IllegalArgumentException if an argument is null/invalid or {@code chainKey} is unreachable
     * @throws IllegalStateException    if the database returns a value that was not requested
     */
    public <F, T> List<T> getValues(ShardKey key, Collection<F> values, ShardKey chainKey, Class<T> chainType) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        Assert.notNull(values, "values cannot be null");
        Assert.notNull(chainKey, "chainKey cannot be null");
        // We modify values in-place, don't touch the user-supplied collection
        ShardKey currentKey = key;
        Object[] currentValues = values.toArray(new Object[0]);
        Assert.noNullElements(currentValues, "value cannot be null");
        while (currentKey != chainKey) {
            if (currentKey == null || currentKey.isRoot()) {
                throw new IllegalArgumentException(
                        "Chain key " + chainKey.getName() + " is unreachable from " + key.getName());
            }
            Cache<Object, Object> cache = getCache(currentKey);
            HashMap<Object, List<Integer>> indexesByValue = new HashMap<>(currentValues.length);
            for (int i = 0; i < currentValues.length; ++i) {
                Object value = currentValues[i];
                if (value == null) {
                    // An earlier pass didn't find the object, so skip it
                    continue;
                }
                value = normalizeValue(value, currentKey);
                Object result = cache.getIfPresent(value);
                if (result == null) {
                    // Not cached: remember every position holding this value so one query fills them all
                    indexesByValue.computeIfAbsent(value, v -> new ArrayList<>()).add(i);
                }
                // This also marks value as empty in case it's never found
                currentValues[i] = result;
            }
            List<Object> valuesForQuery = new ArrayList<>(indexesByValue.keySet());
            for (List<Object> chunk : Lists.partition(valuesForQuery, MAX_CHUNK_SIZE)) {
                runGetValuesQuery(currentKey.getTable(), currentKey.getKeyField(), currentKey.getValueField(), chunk,
                        cache, indexesByValue, currentKey, currentValues);
            }
            currentKey = currentKey.getChainKey();
        }
        List<T> results = new ArrayList<>(currentValues.length);
        for (Object currentValue : currentValues) {
            results.add(normalizeValue(currentValue, chainType));
        }
        return results;
    }

    /**
     * Returns a map of values -> chainKey for each key/value, which may be null if there's no such mapping.
     *
     * @param values values of {@code key}; must not contain nulls (rejected by {@link #getValues})
     */
    public <F, T> Map<F, T> getValuesMap(ShardKey key, List<F> values, ShardKey chainKey, Class<T> chainType) {
        List<T> results = getValues(key, values, chainKey, chainType);

        Map<F, T> ret = new HashMap<>();
        // Iterate values instead of values.get(i): O(n) for any List implementation.
        // results is an ArrayList aligned with values, so positional access on it is cheap.
        int i = 0;
        for (F value : values) {
            ret.put(value, results.get(i++));
        }

        return ret;
    }

    /**
     * Returns the chainKey value for key/value or null if there's no such mapping
     */
    public <T> T getValue(ShardKey key, Object value, ShardKey chainKey, Class<T> chainType) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        Assert.notNull(value, "value cannot be null");
        Assert.notNull(chainKey, "chainKey cannot be null");
        return getValues(key, Collections.singletonList(value), chainKey, chainType).get(0);
    }

    /**
     * Returns a list of shard numbers aligned with {@code values}; an element is null if there's no shard
     * for that value
     */
    public <T> List<Integer> getShards(ShardKey key, List<T> values) {
        return getValues(key, values, ShardKey.SHARD, Integer.class);
    }

    /**
     * Returns the shard number for key/value or {@link #NO_SHARD} if there's no shard for that value
     */
    public int getShard(ShardKey key, Object value) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        Assert.notNull(value, "value cannot be null");
        Integer shard = getShards(key, Collections.singletonList(value)).get(0);
        return shard != null ? shard : NO_SHARD;
    }

    /**
     * Returns {@code MAX(keyField)} over the mapping table of {@code shardKey}, normalized for that key.
     * NOTE(review): MAX over an empty table is SQL NULL — confirm how normalizeValue treats null.
     */
    @SuppressWarnings("unchecked")
    public <T> T getMaxValue(ShardKey shardKey) {
        //noinspection unchecked
        return (T) normalizeValue(getDslContext()
                .select(DSL.max(shardKey.getKeyField()))
                .from(shardKey.getTable())
                .fetchOne()
                .value1(), shardKey);
    }

    /**
     * Adds {@code INNER JOIN table ON right = left}, coercing {@code left} to the type of {@code right}.
     */
    private <T1, T2, T3, T4> SelectJoinStep<Record2<T1, T2>> applyInnerJoinToSelect(
            SelectJoinStep<Record2<T1, T2>> select, Table<?> table, Field<T3> left, Field<T4> right) {
        return select.innerJoin(table).on(right.equal(left.coerce(right)));
    }

    /**
     * Runs the reverse lookup: selects {@code keyField} for every requested {@code chainField} value,
     * joining {@code extraTables} pairwise on {@code intermediateFields}, and appends the normalized keys
     * to the list of the matching value in {@code resultsByValue}.
     *
     * @throws IllegalStateException if the database returns a chain value that was not requested
     */
    private <T1, T2, T> void runLookupKeysQuery(Table<?> mainTable, Field<T1> keyField, Field<T2> chainField,
                                                List<Table<?>> extraTables, List<Field<?>> intermediateFields,
                                                Map<Object, List<T>> resultsByValue, Class<T> type) {
        if (resultsByValue.isEmpty()) {
            // Nothing requested: an empty IN list can never match, so skip the round-trip
            return;
        }
        SelectJoinStep<Record2<T1, T2>> select =
                getDslContext().select(keyField, chainField).from(mainTable);
        for (int i = 0; i < extraTables.size(); ++i) {
            Table<?> table = extraTables.get(i);
            // intermediateFields holds (left, right) join-field pairs, one pair per extra table
            Field<?> left = intermediateFields.get(i * 2);
            Field<?> right = intermediateFields.get(i * 2 + 1);
            select = applyInnerJoinToSelect(select, table, left, right);
        }
        try (Cursor<Record2<T1, T2>> dbResults = select.where(chainField.in(resultsByValue.keySet())).fetchLazy()) {
            for (Record2<T1, T2> dbResult : dbResults) {
                T1 key = dbResult.value1();
                T2 value = dbResult.value2();
                List<T> results = resultsByValue.get(value);
                if (results == null) {
                    throw new IllegalStateException(
                            "Key '" + keyField.getName() + "' returned '" + chainField.getName() + "' value '"
                                    + value + "' from the database, but it was not requested");
                }
                results.add(normalizeValue(key, type));
            }
        }
    }

    /**
     * Returns, for each of {@code chainValues}, the list of {@code key} values mapped to it through the
     * shard-key chain. A single joined query is used; the chain may span at most
     * {@link #MAX_REVERSE_CHAIN_SIZE} tables.
     *
     * @return a list aligned with {@code chainValues}; each element is a (possibly empty) list.
     *         Duplicate chain values share the same list instance.
     * @throws IllegalArgumentException if an argument is null/invalid, {@code key == chainKey},
     *                                  {@code chainKey} is unreachable, or the chain is too long
     * @throws IllegalStateException    if the database returns a chain value that was not requested
     */
    public <T> List<List<T>> lookupKeysWithValues(ShardKey key, ShardKey chainKey, List<?> chainValues, Class<T> type) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        Assert.notNull(chainKey, "chainKey cannot be null");
        Assert.notNull(chainValues, "chainValues cannot be null");
        Assert.isTrue(key != chainKey, "key and chainKey cannot be the same");

        // Prepare all table and fields names from the chain
        List<Table<?>> tables = new ArrayList<>();
        List<Field<?>> fields = new ArrayList<>();
        ShardKey currentKey = key;
        while (currentKey != chainKey) {
            if (currentKey == null || currentKey.isRoot()) {
                throw new IllegalArgumentException(
                        "Chain key " + chainKey.getName() + " is unreachable from " + key.getName());
            }
            tables.add(currentKey.getTable());
            fields.add(currentKey.getKeyField());
            fields.add(currentKey.getValueField());
            currentKey = currentKey.getChainKey();
        }
        if (tables.size() > MAX_REVERSE_CHAIN_SIZE) {
            throw new IllegalArgumentException(
                    "Chain from key " + key.getName() + " to " + chainKey.getName() + " is too long");
        }

        // The loop above ran at least once (key != chainKey), so these lists are non-empty
        Table<?> firstTable = tables.get(0);
        Field<?> keyField = fields.get(0);
        Field<?> chainField = fields.get(fields.size() - 1);

        // Gather all unique values and prepare for storing their results
        Object[] normalizedValues = chainValues.toArray();
        HashMap<Object, List<T>> resultsByValue = new HashMap<>(chainValues.size());
        for (int i = 0; i < normalizedValues.length; ++i) {
            Object chainValue = normalizedValues[i];
            Assert.notNull(chainValue, "chainValue cannot be null");
            Object normalizedValue = normalizeValue(chainValue, chainField);
            normalizedValues[i] = normalizedValue;
            resultsByValue.computeIfAbsent(normalizedValue, v -> new ArrayList<>());
        }

        // Inner fields (without the first key field and the last chain field) describe the joins
        runLookupKeysQuery(firstTable, keyField, chainField, tables.subList(1, tables.size()),
                fields.subList(1, fields.size() - 1), resultsByValue, type);

        // Return results in the same order as chainValues
        List<List<T>> finalResults = new ArrayList<>(normalizedValues.length);
        for (Object normalizedValue : normalizedValues) {
            finalResults.add(resultsByValue.get(normalizedValue));
        }
        return finalResults;
    }

    /**
     * Returns a list of key's values that are mapped to the specified chainKey/chainValue
     * (empty if there are none)
     */
    @Nonnull
    public <T> List<T> lookupKeysWithValue(ShardKey key, ShardKey chainKey, Object chainValue, Class<T> type) {
        return lookupKeysWithValues(key, chainKey, Collections.singletonList(chainValue), type).get(0);
    }

    /**
     * Deletes the mapping rows for {@code keyValues} and invalidates them in the cache of {@code key},
     * if that cache exists.
     */
    private <R extends Record, T1> void runDeleteValuesQuery(ShardKey key, Table<R> table, Field<T1> keyField,
                                                             Collection<?> keyValues) {
        if (keyValues.isEmpty()) {
            // Nothing to delete or invalidate
            return;
        }
        List<Object> normalizedValues = mapList(keyValues, keyValue -> normalizeValue(keyValue, keyField));
        getDslContext()
                .deleteFrom(table)
                .where(keyField.in(normalizedValues))
                .execute();

        // NOTE(review): cache entries are keyed by normalizeValue(value, ShardKey) in getValues, while these
        // values are normalized by the key field — confirm both normalizations produce equal keys.
        Cache<Object, Object> cache = cachesForKeys.get(key);
        if (cache != null) {
            cache.invalidateAll(normalizedValues);
        }
    }

    /**
     * Deletes the mappings of the given values of {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is null/root or any value is null
     */
    public void deleteValues(ShardKey key, Collection<?> values) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        for (Object value : values) {
            Assert.notNull(value, "value cannot be null");
        }

        runDeleteValuesQuery(key, key.getTable(), key.getKeyField(), values);
    }

    /**
     * Inserts one {@code (keyValue, chainValue)} row per element of {@code keyValues} in a single statement.
     * Does nothing if {@code keyValues} is empty, so that no INSERT without rows is issued.
     */
    private <R extends Record, T1, T2> void runSaveValuesQuery(Table<R> table, Field<T1> keyField, Field<T2> chainField,
                                                               Collection<?> keyValues, Object chainValue) {
        if (keyValues.isEmpty()) {
            return;
        }
        InsertValuesStep2<R, T1, T2> step = getDslContext().insertInto(table, keyField, chainField);
        for (Object keyValue : keyValues) {
            step = step.values(normalizeValue(keyValue, keyField),
                    normalizeValue(chainValue, chainField));
        }
        step.execute();
    }

    /**
     * In {@code valuesMap}, each map key is a value of the {@code key} parameter and each map element is a value
     * of the {@code chainKey} parameter. The method groups the keys of all entries that share the same element
     * and saves each group with one insert query.
     *
     * @throws NullPointerException     if {@code valuesMap} is null
     * @throws IllegalArgumentException if the map contains a null key or a null element
     */
    public void saveValues(ShardKey key, ShardKey chainKey, Map<?, ?> valuesMap) {
        checkNotNull(valuesMap, "valuesMap cannot be null");
        checkArgument(!valuesMap.containsKey(null), "key in Map cannot be null");
        checkArgument(!valuesMap.containsValue(null), "element in Map cannot be null");
        EntryStream.of(valuesMap).invert().grouping()
                .forEach((chainValue, values) -> saveValues(key, values, chainKey, chainValue));
    }

    /**
     * Saves an association from key/values to chainKey/chainValue. If {@code chainKey} is not the direct chain key
     * of {@code key}, {@code chainValue} is first mapped to the value of {@code key}'s chain key.
     *
     * @throws IllegalArgumentException if an argument is null/invalid
     * @throws IllegalStateException    if {@code chainValue} cannot be mapped to {@code key}'s chain key
     */
    public void saveValues(ShardKey key, Collection<?> values, ShardKey chainKey, Object chainValue) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        for (Object value : values) {
            Assert.notNull(value, "value cannot be null");
        }
        Assert.notNull(chainKey, "chainKey cannot be null");
        Assert.notNull(chainValue, "chainValue cannot be null");
        if (key.getChainKey() != chainKey) {
            chainValue = getValue(chainKey, chainValue, key.getChainKey(), Object.class);
            if (chainValue == null) {
                throw new IllegalStateException("Cannot find correct value for " + key.getName());
            }
        }
        runSaveValuesQuery(key.getTable(), key.getKeyField(), key.getValueField(), values, chainValue);
    }

    /**
     * Saves an association from key/value to chainKey/chainValue
     */
    public void saveValue(ShardKey key, Object value, ShardKey chainKey, Object chainValue) {
        saveValues(key, Collections.singletonList(value), chainKey, chainValue);
    }

    /**
     * Generates {@code count} new values for key (delegates to the values generator)
     */
    public List<Number> generateValues(AutoIncrementKey key, int count) {
        return valuesGenerator.generateValues(key, count);
    }

    /**
     * Generates a new value for key that is associated with chainKey/chainValue
     */
    public Number generateValue(ShardKey key, ShardKey chainKey, Object chainValue) {
        return generateValues(key, chainKey, singletonList(chainValue)).iterator().next();
    }

    /**
     * Generates one new value of auto-increment {@code key} per element of {@code chainValues}. If
     * {@code chainKey} is not the direct chain key of {@code key}, the chain values are first mapped to it.
     *
     * @throws IllegalArgumentException if an argument is null/invalid or {@code key} is not auto-incremented
     * @throws IllegalStateException    if some chain value cannot be mapped to {@code key}'s chain key
     */
    @QueryWithForbiddenShardMapping(value = "Генерация новых ключей")
    public List<Number> generateValues(ShardKey key, ShardKey chainKey, List<?> chainValues) {
        Assert.notNull(key, "key cannot be null");
        Assert.isTrue(!key.isRoot(), "key must be a valid shard key");
        Assert.isTrue(key.isAutoIncrement(), "key is not auto incremented");
        Assert.notNull(chainKey, "chainKey cannot be null");
        chainValues.forEach(chainValue ->
                Assert.notNull(chainValue, "chainValue cannot be null"));

        if (key.getChainKey() == chainKey) {
            return valuesGenerator.generateValues(key, chainValues);
        }

        List<Object> mappedValues = getValues(chainKey, chainValues, key.getChainKey(), Object.class);

        mappedValues.forEach(mappedValue ->
                checkState(mappedValue != null, "Cannot find correct value for %s", key.getName()));

        return valuesGenerator.generateValues(key, mappedValues);
    }

    /**
     * Selects a shard for a new client (analogue of ShardingTools.get_new_available_shard).
     * <p>
     * Unlike ShardingTools.get_new_available_shard, it does not try to place the client and the operator
     * in the same shard. Instead it picks a random available shard, taking the shard weights into account,
     * and returns the first one in that shuffled order that is alive.
     *
     * @throws NoAvailableShardsException  if no PPC shards are configured
     * @throws AliveShardNotFoundException if none of the shards is alive
     */
    public int selectShardForNewClient() {
        List<ShardWeight> availableShards = this.availablePpcShards.stream()
                .map(shard -> databaseWrapperProvider.getShardWeight(ShardedDb.PPC, shard))
                .collect(toList());
        if (availableShards.isEmpty()) {
            throw new NoAvailableShardsException();
        }

        List<ShardWeight> shuffledShards = weightedShuffle(
                ThreadLocalRandom.current(), availableShards, ShardWeight::getWeight);

        for (ShardWeight shardWeight : shuffledShards) {
            if (databaseWrapperProvider.isAlive(ShardedDb.PPC, shardWeight.getShardNo())) {
                return shardWeight.getShardNo();
            }
        }

        // Every shard has been tried; nothing left to do but complain
        throw new AliveShardNotFoundException();
    }

    /** DSL context of the PPCDICT database, where the shard mapping tables live. */
    private DSLContext getDslContext() {
        return databaseWrapperProvider.get(SimpleDb.PPCDICT).getDslContext();
    }
}
