package ru.yandex.webmaster3.core.util.trie;

import org.apache.commons.lang3.tuple.Pair;

import static ru.yandex.webmaster3.core.util.trie.CompactTrie.SPREAD_TABLE_ITEM_SIZE;

/**
 * @author avhaliullin
 */
public class CompactTrieNodeFactory<V> {
    private final TrieValueCodec<V> valueCodec;
    private final TrieBufferReader reader;

    public CompactTrieNodeFactory(TrieValueCodec<V> valueCodec, TrieBufferReader reader) {
        this.valueCodec = valueCodec;
        this.reader = reader;
    }

    public ITrieNode<V> getRootNode() {
        return unpackNode(0).getLeft();
    }

    private Pair<TrieNode, V> unpackNode(int offset) {
        reader.setOffset(offset);
        byte header = reader.readByte();
        boolean haveData = CompactTrie.haveData(header);
        if (CompactTrie.isStrike(header)) {
            int strikeLen = CompactTrie.readStrikeLen(reader, header);
            V data = null;
            if (haveData) {
                data = valueCodec.read(reader);
            }
            TrieNode node = null;
            if (strikeLen > 0) {
                node = new StrikeTrieNode(reader.offset, strikeLen, 0);
            }
            return Pair.of(node, data);
        } else {
            int spreadTableSize = CompactTrie.readSpreadTableSize(reader, header);
            byte baseChar = reader.readByte();
            V data = null;
            if (haveData) {
                data = valueCodec.read(reader);
            }
            boolean haveJump = CompactTrie.spreadHaveJump(header);
            int spreadIdx = 0;
            if (!haveJump) {
                spreadTableSize--;
                spreadIdx--;
            }
            TrieNode node = new SpreadTrieNode(reader.getOffset(), spreadTableSize, spreadIdx, baseChar, haveJump);
            return Pair.of(node, data);
        }
    }

    abstract class TrieNode implements ITrieNode<V> {
    }

    class StrikeTrieNode extends TrieNode {
        private final int strikeStartOffset;
        private final int strikeLen;
        private final int strikeIndex;
        private V data;
        private TrieNode midChild;
        private boolean midResolved;

        public StrikeTrieNode(int strikeStartOffset, int strikeLen, int strikeIndex) {
            this.strikeStartOffset = strikeStartOffset;
            this.strikeLen = strikeLen;
            this.strikeIndex = strikeIndex;
        }

        @Override
        public void freeMemory() {
            data = null;
            midChild = null;
            midResolved = false;
        }

        private void resolveMid() {
            if (midResolved) {
                return;
            }
            if (strikeIndex == strikeLen - 1) {
                Pair<TrieNode, V> pair = unpackNode(strikeStartOffset + strikeLen);
                midChild = pair.getLeft();
                data = pair.getRight();
            } else {
                midChild = new StrikeTrieNode(strikeStartOffset, strikeLen, strikeIndex + 1);
            }
            midResolved = true;
        }

        @Override
        public V getData() {
            resolveMid();
            return data;
        }

        @Override
        public ITrieNode<V> getHiChild() {
            return null;
        }

        @Override
        public ITrieNode<V> getMidChild() {
            resolveMid();
            return midChild;
        }

        @Override
        public byte getLabel() {
            return reader.buffer[strikeStartOffset + strikeIndex];
        }
    }

    class SpreadTrieNode extends TrieNode {
        private final int spreadTableStartOffset;
        private final int spreadTableSize;
        private final int spreadTableIndex;
        private final byte curChar;
        private final boolean haveJump;
        private TrieNode hiChild;
        private TrieNode midChild;
        private V data;
        private boolean midResolved;
        private boolean hiResolved;

        public SpreadTrieNode(int spreadTableStartOffset, int spreadTableSize, int spreadTableIndex, byte curChar, boolean haveJump) {
            this.spreadTableStartOffset = spreadTableStartOffset;
            this.spreadTableSize = spreadTableSize;
            this.spreadTableIndex = spreadTableIndex;
            this.curChar = curChar;
            this.haveJump = haveJump;
        }

        @Override
        public void freeMemory() {
            hiChild = null;
            hiResolved = false;
            midChild = null;
            midResolved = false;
            data = null;
        }

        private int nextNodeOffset() {
            return spreadTableStartOffset + spreadTableSize * SPREAD_TABLE_ITEM_SIZE;
        }

        private int readSpreadTable(int index) {
            reader.setOffset(spreadTableStartOffset + index * SPREAD_TABLE_ITEM_SIZE);
            return reader.readInt();
        }

        private void resolveMid() {
            if (midResolved) {
                return;
            }
            Pair<TrieNode, V> pair;
            if (spreadTableIndex < 0) {
                pair = unpackNode(nextNodeOffset());
            } else {
                pair = unpackNode(readSpreadTable(spreadTableIndex));
            }
            midChild = pair.getLeft();
            data = pair.getRight();
            midResolved = true;
        }

        private void resolveHi() {
            if (hiResolved) {
                return;
            }

            if (spreadTableIndex < spreadTableSize - 1) {
                int nextIdx = spreadTableIndex + 1;
                byte nextChar = (byte) (curChar + 1);
                while (readSpreadTable(nextIdx) == 0) {
                    nextIdx++;
                    nextChar++;
                }
                hiChild = new SpreadTrieNode(spreadTableStartOffset, spreadTableSize, nextIdx, nextChar, haveJump);
            } else if (haveJump) {
                Pair<TrieNode, V> pair = unpackNode(nextNodeOffset()); // у spread-ноды по jump'у не должно быть данных
                hiChild = pair.getLeft();
            }
            hiResolved = true;
        }

        @Override
        public ITrieNode<V> getHiChild() {
            resolveHi();
            return hiChild;
        }

        @Override
        public ITrieNode<V> getMidChild() {
            resolveMid();
            return midChild;
        }

        @Override
        public V getData() {
            resolveMid();
            return data;
        }

        @Override
        public byte getLabel() {
            return curChar;
        }
    }

}
