package ru.yandex.webmaster3.core.util.trie;

import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;

/**
 * @author avhaliullin
 */
public class CompactTrie<V> {
    static final byte F_STRIKE_NODE = (byte) (1 << 7);
    static final byte F_HAVE_DATA = 1 << 6;
    static final byte F_HAVE_JUMP = 1 << 5;
    static final byte F_STRIKE_LEN_MASK = 0b00111111;
    static final byte F_SPREAD_TABLE_SIZE_MASK = 0b00011111;

    static final int SPREAD_TABLE_ITEM_SIZE = 4;

    private final byte[] data;
    private final int dataSize;
    private final TrieValueCodec<V> valueCodec;

    private CompactTrie(byte[] data, int dataSize, TrieValueCodec<V> valueCodec) {
        this.data = data;
        this.dataSize = dataSize;
        this.valueCodec = valueCodec;
    }

    public int getDataSizeBytes() {
        return data.length;
    }

    public void foreach(EntryConsumer<V> entryConsumer, KeyConsumer keyConsumer) {
        if (entryConsumer != null && keyConsumer != null) {
            throw new IllegalArgumentException();
        }
        TrieBufferWriter keyBuff = new TrieBufferWriter();
        TrieBufferReader reader = new TrieBufferReader(data);
        foreach(entryConsumer, keyConsumer, keyBuff, 0, reader);
    }

    public ITrieNode<V> getRootNode() {
        return new CompactTrieNodeFactory<V>(valueCodec, new TrieBufferReader(data)).getRootNode();
    }

    public TrieValueCodec<V> getValueCodec() {
        return valueCodec;
    }

    private V handleData(EntryConsumer<V> entryConsumer, TrieBufferReader reader) {
        if (entryConsumer != null) {
            return valueCodec.read(reader);
        } else {
            valueCodec.skip(reader);
            return null;
        }
    }

    private void accept(EntryConsumer<V> entryConsumer, KeyConsumer keyConsumer, byte[] key, int keyLen, V value) {
        if (entryConsumer != null && keyConsumer != null) {
            throw new IllegalArgumentException();
        }
        if (entryConsumer != null) {
            entryConsumer.accept(key, keyLen, value);
        } else {
            keyConsumer.accept(key, keyLen);
        }
    }

    private void foreach(EntryConsumer<V> entryConsumer, KeyConsumer keyConsumer, TrieBufferWriter keyBuff, int keyOffset, TrieBufferReader reader) {
        while (true) {
            byte header = reader.readByte();
            boolean haveData = haveData(header);
            if (isStrike(header)) {
                //strike
                int strikeLen = readStrikeLen(reader, header);

                V data = null;
                if (haveData) {
                    data = handleData(entryConsumer, reader);
                }
                keyBuff.assureSize(keyOffset + strikeLen);
                reader.readBytes(keyBuff.getBuffer(), keyOffset, strikeLen);
                if (haveData) {
                    accept(entryConsumer, keyConsumer, keyBuff.getBuffer(), keyOffset, data);
                }
                keyOffset += strikeLen;
                if (strikeLen == 0) { // листовая нода с данными
                    return;
                }
            } else {
                //spread
                boolean haveJump = spreadHaveJump(header);
                int spreadTableSize = readSpreadTableSize(reader, header);
                byte baseChar = reader.readByte();
                if (haveData) {
                    accept(entryConsumer, keyConsumer, keyBuff.getBuffer(), keyOffset, valueCodec.read(reader));
                }
                int spreadTableOffset = reader.getOffset();
                reader.skip(SPREAD_TABLE_ITEM_SIZE * (spreadTableSize - (haveJump ? 0 : 1)));
                byte curChar = baseChar;
                // после текущей ноды идет либо следующий spread (haveJump == true),
                // либо первый ребенок этого спреда, не отмеченный в таблице (haveJump == false)
                if (haveJump) {
                    foreach(entryConsumer, keyConsumer, keyBuff, keyOffset, reader);
                } else {
                    keyBuff.writeByte(keyOffset, curChar);
                    foreach(entryConsumer, keyConsumer, keyBuff, keyOffset + 1, reader);
                    curChar++;
                }
                for (byte i = 0; i < spreadTableSize - (haveJump ? 0 : 1); i++) {
                    // можно было просто пробежаться по offset'ам из таблицы, но есть предположение,
                    // что это может оказаться медленнее из-за cache miss'ов при random access
                    keyBuff.writeByte(keyOffset, curChar);
                    curChar++;
                    int curNodePosition = reader.getOffset();
                    reader.setOffset(spreadTableOffset + i * SPREAD_TABLE_ITEM_SIZE);
                    // при этом offset'ы нам все равно придется вычитывать, потому что в них может быть написан 0,
                    // значит для текущего символа нет ноды
                    int nextNode = reader.readInt();
                    reader.setOffset(curNodePosition);
                    if (nextNode == 0) {
                        continue;
                    }
                    // раз уж у нас на руках оказалась информация о смещении из двух источников - грех не проверить, что она сходится
                    if (nextNode != curNodePosition) {
                        throw new RuntimeException("Position from table " + nextNode + ", actual " + curNodePosition);
                    }
                    foreach(entryConsumer, keyConsumer, keyBuff, keyOffset + 1, reader);
                }
                return;
            }
        }
    }

    public V get(byte[] key) {
        TrieBufferReader reader = new TrieBufferReader(data);
        int keyOffset = 0;
        while (true) {
            boolean consumedKey = keyOffset == key.length;
            byte header = reader.readByte();
            boolean haveData = haveData(header);
            if (isStrike(header)) {
                //strike
                int strikeLen = readStrikeLen(reader, header);
                if (consumedKey) {
                    if (haveData) {
                        return valueCodec.read(reader);
                    } else {
                        return null; // ключ сожрали, а в ноде данных нет
                    }
                } else {
                    if (haveData) {
                        valueCodec.skip(reader);
                    }
                }
                if (strikeLen > key.length - keyOffset) {
                    return null; // ключ заканчивается где-то посреди strike'а, данных там нет
                }
                int cmp = reader.compare(key, keyOffset, strikeLen);

                keyOffset += strikeLen;
                if (cmp == 0) {
                    // проваливаемся дальше
                    continue;
                } else {
                    return null; // просто не совпала подстрока. Из strike-ноды ветвлений нет, так что это конец
                }
            } else {
                //spread
                boolean haveJump = spreadHaveJump(header);
                int spreadTableSize = readSpreadTableSize(reader, header);
                byte baseChar = reader.readByte();
                int spreadTableSizeBytes = (spreadTableSize - (haveJump ? 0 : 1)) * SPREAD_TABLE_ITEM_SIZE;
                if (consumedKey) {
                    if (haveData) {
                        return valueCodec.read(reader);
                    } else {
                        return null;
                    }
                }
                byte curChar = key[keyOffset];
                if (curChar < baseChar) {
                    return null; // ветвлений влево не бывает. Если нужно влево - значит нет такого ключа
                } else if (curChar < baseChar + spreadTableSize) { // попадаем куда-то внутрь таблицы, многообещающая ветка
                    if (haveData) {
                        valueCodec.skip(reader);
                    }
                    int tablePosition = curChar - baseChar;
                    if (!haveJump) {
                        // если за spread'ом нет следующего, то первый offset в таблице пропущен
                        tablePosition--;
                        if (tablePosition < 0) {
                            // нам нужна первая нода в таблице - она просто следует за текущей (после таблицы)
                            reader.skip(spreadTableSizeBytes);
                            keyOffset++;
                            continue;
                        }
                    }
                    reader.skip(tablePosition * SPREAD_TABLE_ITEM_SIZE);
                    int jumpOffset = reader.readInt();
                    if (jumpOffset == 0) {
                        return null; // попали на дырку в spread'е, значит перехода по этому символу нет
                    }
                    reader.setOffset(jumpOffset);
                    keyOffset++;
                } else { // промахнулись вправо от таблицы, еще есть шанс попасть в следующий spread
                    if (!haveJump) {
                        return null; //нет, этот был последний
                    } else {
                        if (haveData) {
                            valueCodec.skip(reader);
                        }
                        reader.skip(spreadTableSizeBytes);
                    }
                }
            }
        }
    }

    public boolean contains(byte[] key) {
        TrieBufferReader reader = new TrieBufferReader(data);
        int keyOffset = 0;
        while (true) {
            boolean consumedKey = keyOffset == key.length;
            byte header = reader.readByte();
            boolean haveData = haveData(header);
            if (isStrike(header)) {
                //strike
                int strikeLen = readStrikeLen(reader, header);
                if (consumedKey) {
                    return haveData;
                } else {
                    if (haveData) {
                        valueCodec.skip(reader);
                    }
                }
                if (strikeLen > key.length - keyOffset) {
                    return false; // ключ заканчивается где-то посреди strike'а, данных там нет
                }
                int cmp = reader.compare(key, keyOffset, strikeLen);

                keyOffset += strikeLen;
                if (cmp == 0) {
                    // проваливаемся дальше
                    continue;
                } else {
                    return false; // просто не совпала подстрока. Из strike-ноды ветвлений нет, так что это конец
                }
            } else {
                //spread
                boolean haveJump = spreadHaveJump(header);
                int spreadTableSize = readSpreadTableSize(reader, header);
                byte baseChar = reader.readByte();
                int spreadTableSizeBytes = (spreadTableSize - (haveJump ? 0 : 1)) * SPREAD_TABLE_ITEM_SIZE;
                if (consumedKey) {
                    return haveData;
                }
                byte curChar = key[keyOffset];
                if (curChar < baseChar) {
                    return false; // ветвлений влево не бывает. Если нужно влево - значит нет такого ключа
                } else if (curChar < baseChar + spreadTableSize) { // попадаем куда-то внутрь таблицы, многообещающая ветка
                    if (haveData) {
                        valueCodec.skip(reader);
                    }
                    int tablePosition = curChar - baseChar;
                    if (!haveJump) {
                        // если за spread'ом нет следующего, то первый offset в таблице пропущен
                        tablePosition--;
                        if (tablePosition < 0) {
                            // нам нужна первая нода в таблице - она просто следует за текущей (после таблицы)
                            reader.skip(spreadTableSizeBytes);
                            keyOffset++;
                            continue;
                        }
                    }
                    reader.skip(tablePosition * SPREAD_TABLE_ITEM_SIZE);
                    int jumpOffset = reader.readInt();
                    if (jumpOffset == 0) {
                        return false; // попали на дырку в spread'е, значит перехода по этому символу нет
                    }
                    reader.setOffset(jumpOffset);
                    keyOffset++;
                } else { // промахнулись вправо от таблицы, еще есть шанс попасть в следующий spread
                    if (!haveJump) {
                        return false; //нет, этот был последний
                    } else {
                        if (haveData) {
                            valueCodec.skip(reader);
                        }
                        reader.skip(spreadTableSizeBytes);
                    }
                }
            }
        }
    }

    public CompactTrie<V> shrink() {
        if (data.length * 1.1 > dataSize) {
            byte[] result = new byte[dataSize];
            System.arraycopy(data, 0, result, 0, dataSize);
            return new CompactTrie<>(result, dataSize, valueCodec);
        } else {
            return this;
        }
    }

    static int readSpreadTableSize(TrieBufferReader reader, byte header) {
        int spreadTableSize = header & F_SPREAD_TABLE_SIZE_MASK;
        if (spreadTableSize == F_SPREAD_TABLE_SIZE_MASK) {
            spreadTableSize = reader.readRawVarUInt32() + F_SPREAD_TABLE_SIZE_MASK;
        }
        return spreadTableSize + 1;
    }

    static int readStrikeLen(TrieBufferReader reader, byte header) {
        int strikeLen = header & F_STRIKE_LEN_MASK;
        if (strikeLen == F_STRIKE_LEN_MASK) {
            strikeLen = reader.readRawVarUInt32() + F_STRIKE_LEN_MASK;
        }
        return strikeLen;
    }

    static boolean spreadHaveJump(byte headerByte) {
        return (headerByte & F_HAVE_JUMP) != 0;
    }

    static boolean haveData(byte headerByte) {
        return (headerByte & F_HAVE_DATA) != 0;
    }

    static boolean isStrike(byte headerByte) {
        return (headerByte & F_STRIKE_NODE) != 0;
    }

    public static <V> CompactTrie<V> fromNode(CompactTrieBuilder<V> trie, TrieValueCodec<V> driver) {
        return fromNode(trie.root, driver);
    }

    public static <V> CompactTrie<V> fromNode(ITrieNode<V> trie, TrieValueCodec<V> driver) {
        TrieBufferWriter buffer = new TrieBufferWriter();
        TrieBufferWriter supportBuffer = new TrieBufferWriter();
        int size = writeNode(trie, buffer, 0, driver, null, supportBuffer);
        return new CompactTrie<>(buffer.getBuffer(), size, driver);
    }

    private static <V> int writeNode(ITrieNode<V> node, TrieBufferWriter buffer, int offset, TrieValueCodec<V> driver, V data, TrieBufferWriter supportBuffer) {
        if (node == null) { // dummy-нода для записи данных
            if (data == null) {
                throw new RuntimeException("Useless node");
            }
            // у любой ноды должен быть заголовок и тип - для таких фейковых нод используем strike с пустой подстрокой
            offset = writeStrikeHeader(buffer, offset, 0, true);
            return driver.write(data, buffer, offset);
        } else if (node.getHiChild() == null) {
            // strike, будем собирать подстроку какого-то ключа (длиной 1 или больше)
            ITrieNode<V> cur = node;
            int strikeLen = 0;
            strikeLen = supportBuffer.writeByte(strikeLen, cur.getLabel());
            // собираем в supportBuffer подстроку, пока не столкнемся с ветвлением или необходимостью записать данные
            while (cur.getMidChild() != null && cur.getMidChild().getHiChild() == null && cur.getData() == null) {
                cur = cur.getMidChild();
                strikeLen = supportBuffer.writeByte(strikeLen, cur.getLabel());
            }

            offset = writeStrikeHeader(buffer, offset, strikeLen, data != null); // заголовок
            if (data != null) {
                offset = driver.write(data, buffer, offset); // данные, если есть
            }
            offset = buffer.writeBytes(offset, supportBuffer.getBuffer(), 0, strikeLen); // strike-подстрока
            // неочевидный момент - в этом месте всегда или cur.getMidChild(), или cur.getData() - не null
            // если они оба null - то мы прошли по какому-то символу, за которым нет ни данных, ни продолжения строки, чего быть не может
            return writeNode(cur.getMidChild(), buffer, offset, driver, cur.getData(), supportBuffer);
        } else {
            // spread - тут ветвление по разным символам, компактим в таблички
            List<List<ITrieNode<V>>> spreads = new ArrayList<>();
            List<ITrieNode<V>> currentSpread = null;

            ITrieNode<V> cur = node;
            byte prevByte = 0;
            // собираем все ветвления в плотные группы
            while (cur != null) {
                // теоретически, чтобы памяти расходовать как можно меньше - дыр вообще не должно быть, но
                // 1) теоретически, чем больше допускаем дыр, тем меньше спредов получится и тем быстрее будет работать поиск по ключу
                // 2) на практике, если дыр вообще не допускать - то памяти в итоге тратится еще больше, чем с дырами, что странно
                if (currentSpread == null || cur.getLabel() > prevByte + 1) { //TODO пока подобрали более-менее оптимальное значение, но надо поисследовать
                    currentSpread = new ArrayList<>();
                    spreads.add(currentSpread);
                }
                currentSpread.add(cur);
                prevByte = cur.getLabel();
                cur = cur.getHiChild();
            }
            IdentityHashMap<ITrieNode<V>, Integer> jumpOffsets = new IdentityHashMap<>();
            for (int spreadIdx = 0; spreadIdx < spreads.size(); spreadIdx++) {
                // данные записываем только в первую ноду - именно в нее мы придем при поиске
                boolean writeData = data != null && spreadIdx == 0;
                List<ITrieNode<V>> spread = spreads.get(spreadIdx);
                byte baseChar = spread.get(0).getLabel();
                byte lastChar = spread.get(spread.size() - 1).getLabel();
                boolean last = spreadIdx == spreads.size() - 1;
                int tableSize = lastChar - baseChar + 1;
                offset = writeSpreadHeader(buffer, offset, baseChar, tableSize, writeData, !last); //заголовок
                if (writeData) {
                    offset = driver.write(data, buffer, offset); //данные, если есть
                }
                // если это последний spread на уровне, то можно соптимизировать таблицу переходов:
                // 1. Гарантируем, что сразу за этой нодой будет лежать первая нода в таблице
                // 2. Не записываем offset первой ноды в таблицу
                int implicitFirstJump = last ? 1 : 0;
                int tableSizeBytes = SPREAD_TABLE_ITEM_SIZE * (tableSize - implicitFirstJump);
                int jumpPositionOffset = offset;
                // двигаемся за spread-таблицу. Мы заранее не знаем, сколько займут дети, поэтому их смещения запишем, когда запишем самих детей
                offset += tableSizeBytes;

                byte curChar = (byte) (baseChar + implicitFirstJump);
                // все еще помним, что при last==true мы не пишем первый offset
                for (int tableIdx = implicitFirstJump; tableIdx < spread.size(); tableIdx++) {
                    ITrieNode<V> spreadNode = spread.get(tableIdx);
                    while (curChar < spreadNode.getLabel()) { // у нас могли быть дыры - нужно их пролистать
                        jumpPositionOffset += SPREAD_TABLE_ITEM_SIZE;
                        curChar++;
                    }
                    jumpOffsets.put(spreadNode, jumpPositionOffset);
                }
            }
            // записываем детей из spread'ов, итерируемся в обратном порядке,
            // потому что сразу после записи последнего спреда мы ожидаем увидеть его первого ребенка
            for (List<ITrieNode<V>> spread : Lists.reverse(spreads)) {
                for (ITrieNode<V> spreadNode : spread) {
                    Integer jumpOffset = jumpOffsets.get(spreadNode);
                    if (jumpOffset != null) { // бывает null для дыр - оставляем ноль. Так мы при чтении поймем, что тут дыра
                        buffer.writeInt(jumpOffset, offset);
                    }
                    offset = writeNode(spreadNode.getMidChild(), buffer, offset, driver, spreadNode.getData(), supportBuffer);
                    spreadNode.freeMemory();
                }
            }
            return offset;
        }
    }

    /**
     * Раскладка по битам
     * 0 _ _ _ _ _ _ _ - тип ноды (spread, бит выставлен в 0)
     * 0 x _ _ _ _ _ _ - есть ли данные
     * 0 _ x _ _ _ _ _ - jump-флаг - есть ли следующий spread за этим
     * 0 _ _ x x x x x - 5 бит на размер spread-таблицы
     *
     * Если размер таблицы не влазит в 5 бит - запишем в конце все единицы, а размер запишем дальше
     */
    private static int writeSpreadHeader(TrieBufferWriter buffer, int offset, byte baseChar, int tableSize, boolean haveData, boolean haveJump) {
        tableSize--; // пустая таблица не имеет смысла, поэтому сместим размер на 1

        byte tableSizeInHeader;
        boolean bigTable = tableSize >= F_SPREAD_TABLE_SIZE_MASK;
        if (bigTable) {
            tableSizeInHeader = F_SPREAD_TABLE_SIZE_MASK;
        } else {
            tableSizeInHeader = (byte) tableSize;
        }
        byte flags = tableSizeInHeader;
        if (haveData) {
            flags |= F_HAVE_DATA;
        }
        if (haveJump) {
            flags |= F_HAVE_JUMP;
        }
        offset = buffer.writeByte(offset, flags);
        if (bigTable) {
            // раз размер не влез в 2^5 - 1, мы можем вычесть 2^5 - 1 из размера, при чтении добавить обратно
            // вычитать имеет смысл потому, что writeVarUInt32 записывает меньшие числа компактнее
            offset = buffer.writeVarUInt32(offset, tableSize - F_SPREAD_TABLE_SIZE_MASK);
        }
        offset = buffer.writeByte(offset, baseChar);
        return offset;
    }

    /**
     * Раскладка по битам
     * 1 _ _ _ _ _ _ _ - тип ноды (strike, бит выставлен в 1)
     * 1 x _ _ _ _ _ _ - есть ли данные
     * 1 _ x x x x x x - 6 бит на длину подстроки
     *
     * Если размер длина подстроки не влазит в 6 бит - запишем его дальше отдельно
     */
    private static int writeStrikeHeader(TrieBufferWriter buffer, int offset, int strikeLen, boolean haveData) {
        byte strikeLenInHeader;
        boolean bigStrike = strikeLen >= F_STRIKE_LEN_MASK;
        if (bigStrike) {
            strikeLenInHeader = F_STRIKE_LEN_MASK;
        } else {
            strikeLenInHeader = (byte) strikeLen;
        }
        byte flags = (byte) (F_STRIKE_NODE | strikeLenInHeader);
        if (haveData) {
            flags |= F_HAVE_DATA;
        }
        offset = buffer.writeByte(offset, flags);
        if (bigStrike) {
            // раз размер не влез в 2^6 - 1, мы можем вычесть 2^6 - 1 из размера, при чтении добавить обратно
            // вычитать имеет смысл потому, что writeVarUInt32 записывает меньшие числа компактнее
            offset = buffer.writeVarUInt32(offset, strikeLen - F_STRIKE_LEN_MASK);
        }
        return offset;
    }

    public interface EntryConsumer<V> {
        void accept(byte[] key, int keyLen, V value);
    }

    public interface KeyConsumer {
        void accept(byte[] key, int keyLen);
    }
}
