package ru.yandex.analyzer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;

import java.io.IOException;

import ru.yandex.util.string.HexStrings;

/**
 * A {@link TokenFilter} that collapses the whole input stream into a set of
 * stacked "block" tokens.
 *
 * <p>All upstream terms are buffered and split into {@code n} contiguous
 * blocks of (nearly) equal token count. For every way of dropping {@code r}
 * of those blocks, one output token is emitted. Each output token holds the
 * terms of the remaining {@code n - r} blocks, in original order, with each
 * term followed by {@code '|'}. The first output token of a stream gets
 * position increment 1; the rest are stacked on the same position
 * (increment 0).
 *
 * <p>The permutation tables are shared between instances through a static
 * cache and are treated as read-only.
 */
public final class MultiTokenPermuteFilter extends TokenFilter {
    /**
     * Permutation tables keyed by {@code (r << 32) | n}. Tables are shared
     * between filter instances and must never be mutated.
     */
    private static final ConcurrentHashMap<Long, int[][]> CACHE =
        new ConcurrentHashMap<>();
    private static final int INITIAL_TOKENS_CAPACITY = 32;

    private final CharTermAttribute termAtt;
    private final PositionIncrementAttribute posAtt;
    /** Number of blocks the buffered stream is split into. */
    private final int n;
    /** Number of blocks dropped from every emitted token. */
    private final int r;
    /** Every way to drop {@code r} of {@code n} blocks, see {@link #generate}. */
    private final int[][] permutations;
    /** Index of the next permutation to emit; -1 before the first read. */
    private int currentPermutation = -1;
    /** Reusable buffer of upstream terms; grows on demand. */
    private Token[] tokens;
    /** Number of valid entries in {@link #tokens}. */
    private int tokenCount;
    /** Upper bound on an emitted term length: sum of (term length + 1). */
    private int maxLen;

    /**
     * Creates the filter.
     *
     * @param in upstream token stream
     * @param n number of blocks to split the stream into; must be positive
     * @param r number of blocks dropped from each emitted token; must be in
     *     {@code [0, n]}
     * @throws IllegalArgumentException if {@code n} or {@code r} is out of range
     */
    public MultiTokenPermuteFilter(final TokenStream in, final int n, final int r) {
        super(in);
        if (n < 1) {
            throw new IllegalArgumentException("n must be positive: " + n);
        }
        if (r < 0 || r > n) {
            throw new IllegalArgumentException(
                "r must be in [0, n]: r=" + r + ", n=" + n);
        }
        termAtt = addAttribute(CharTermAttribute.class);
        posAtt = addAttribute(PositionIncrementAttribute.class);
        this.n = n;
        this.r = r;
        permutations = loadPermutations(n, r);
    }

    /**
     * Returns the cached permutation table for {@code (n, r)}, generating it
     * on first use.
     */
    private static int[][] loadPermutations(final int n, final int r) {
        // Widen before shifting: an int shift by 32 is a shift by 0, which
        // made the old key n + r and let different (n, r) pairs collide.
        long key = ((long) r << 32) | (n & 0xFFFFFFFFL);
        int[][] perms = CACHE.get(key);
        if (perms == null) {
            int[][] generated = generate(n, r);
            // putIfAbsent makes every instance share the same table even if
            // two threads race to generate it.
            perms = CACHE.putIfAbsent(key, generated);
            if (perms == null) {
                perms = generated;
            }
        }
        return perms;
    }

    /**
     * Generates one permutation of {@code 0..n-1} for every combination of
     * {@code r} elements, in lexicographic order of the combinations.
     *
     * <p>In each returned permutation, the last {@code r} entries are the
     * chosen combination (ascending). The first {@code n - r} entries are the
     * remaining elements (ascending). For {@code r == 0} the single result is
     * the identity permutation. For {@code r > n} the result is empty.
     *
     * @param n the number of elements in the input set
     * @param r the number of elements in a combination
     * @return one permutation per combination
     */
    public static int[][] generate(int n, int r) {
        if (r == 0) {
            // Exactly one way to choose nothing; also avoids combination[-1].
            return new int[][] {complementFirst(new int[0], n)};
        }
        ArrayList<int[]> result = new ArrayList<>();
        int[] combination = new int[r];
        // initialize with lowest lexicographic combination
        for (int i = 0; i < r; i++) {
            combination[i] = i;
        }
        while (combination[r - 1] < n) {
            result.add(complementFirst(combination, n));
            // generate next combination in lexicographic order
            int t = r - 1;
            while (t != 0 && combination[t] == n - r + t) {
                t--;
            }
            combination[t]++;
            for (int i = t + 1; i < r; i++) {
                combination[i] = combination[i - 1] + 1;
            }
        }
        return result.toArray(new int[0][]);
    }

    /**
     * Builds a permutation of {@code 0..n-1}. It lists the elements absent from
     * {@code chosen} (ascending) first, followed by {@code chosen} itself.
     * {@code chosen} must be strictly ascending.
     */
    private static int[] complementFirst(final int[] chosen, final int n) {
        int k = n - chosen.length;
        int[] perm = new int[n];
        System.arraycopy(chosen, 0, perm, k, chosen.length);
        // chosen is sorted, so a single merge-style pass finds the complement.
        int nextChosen = 0;
        int pos = 0;
        for (int value = 0; value < n; value++) {
            if (nextChosen < chosen.length && chosen[nextChosen] == value) {
                nextChosen++;
            } else {
                perm[pos++] = value;
            }
        }
        return perm;
    }

    @Override
    public final boolean incrementToken() throws IOException {
        final int positionIncrement;
        if (currentPermutation == -1
            || currentPermutation >= permutations.length)
        {
            // No round in progress: buffer the whole upstream stream.
            if (!bufferInput()) {
                return false;
            }
            currentPermutation = 0;
            // Do not inherit whatever increment the upstream's last token left.
            positionIncrement = 1;
        } else {
            // Remaining permutations are stacked on the first one's position.
            positionIncrement = 0;
        }
        // NOTE(review): other attributes (offsets, type, ...) are left as the
        // upstream last set them; confirm that is intended.
        char[] buffer = termAtt.resizeBuffer(maxLen);
        termAtt.setLength(generatePermutedBlock(currentPermutation++, buffer));
        posAtt.setPositionIncrement(positionIncrement);
        return true;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        // Drop any half-emitted round so a reused stream starts fresh.
        currentPermutation = -1;
        tokenCount = 0;
        maxLen = 0;
    }

    /**
     * Copies every remaining upstream term into {@link #tokens} and computes
     * {@link #maxLen}.
     *
     * @return {@code true} if at least one term was read
     */
    private boolean bufferInput() throws IOException {
        tokenCount = 0;
        maxLen = 0;
        while (input.incrementToken()) {
            if (tokens == null) {
                tokens = new Token[INITIAL_TOKENS_CAPACITY];
            } else if (tokenCount == tokens.length) {
                tokens = Arrays.copyOf(tokens, tokenCount << 1);
            }
            int len = termAtt.length();
            char[] buffer = termAtt.buffer();
            if (tokens[tokenCount] == null) {
                tokens[tokenCount] = new Token(buffer, len);
            } else {
                tokens[tokenCount].copyFrom(buffer, len);
            }
            tokenCount++;
            // +1 for the '|' separator appended after each term.
            maxLen += len + 1;
        }
        return tokenCount > 0;
    }

    /**
     * Writes the kept blocks of permutation {@code permutationIdx} into
     * {@code buffer}, each term followed by {@code '|'}.
     *
     * @param permutationIdx index into {@link #permutations}
     * @param buffer destination, at least {@link #maxLen} chars long
     * @return number of chars written
     */
    private int generatePermutedBlock(
        final int permutationIdx,
        final char[] buffer)
    {
        int[] perm = permutations[permutationIdx];
        int len = 0;
        int k = n - r;
        // Only the first n - r entries (the kept blocks) are emitted.
        for (int i = 0; i < k; i++) {
            int blockIdx = perm[i];
            int blockStartToken = blockBoundary(blockIdx);
            int blockEndToken = blockBoundary(blockIdx + 1);
            for (int j = blockStartToken; j < blockEndToken; j++) {
                Token token = tokens[j];
                System.arraycopy(token.buf, 0, buffer, len, token.len);
                len += token.len;
                buffer[len++] = '|';
            }
        }
        return len;
    }

    /**
     * Returns the index of the first token of block {@code blockIdx}, i.e.
     * {@code floor(blockIdx * tokenCount / n)}. Integer arithmetic is exact,
     * so the blocks always partition {@code [0, tokenCount)} with no token
     * dropped by float rounding.
     */
    private int blockBoundary(final int blockIdx) {
        return (int) ((long) blockIdx * tokenCount / n);
    }

    /** Reusable copy of one upstream term. */
    private static final class Token {
        private char[] buf;
        private int len;

        Token(final char[] buf, final int len) {
            this.buf = Arrays.copyOf(buf, len);
            this.len = len;
        }

        /** Overwrites this token with {@code len} chars of {@code src}. */
        public void copyFrom(final char[] src, final int len) {
            if (buf.length < len) {
                buf = Arrays.copyOf(src, len);
            } else {
                System.arraycopy(src, 0, buf, 0, len);
            }
            this.len = len;
        }
    }
}
