package ru.yandex.webmaster3.core.util.trie;

import org.apache.commons.lang3.tuple.Pair;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;

/**
 * @author avhaliullin
 */
/**
 * Mutable builder for a compact trie keyed by non-empty byte arrays.
 *
 * <p>The structure is a ternary-search-tree variant without "lo" links. Each trie level is a
 * singly-linked chain of sibling nodes connected through {@code hiChild}. The chain is kept in
 * ascending order of the nodes' {@code label}, with bytes compared as signed values.
 * {@code midChild} points to the chain for the next key byte. A value is attached to the node
 * matching the last byte of its key.
 *
 * <p>Iteration yields entries in lexicographic key order: bytes are compared as signed values,
 * and a key comes before its extensions. Entries whose value is {@code null} are skipped.
 * This class performs no synchronization.
 *
 * @param <V> type of the stored values
 * @author avhaliullin
 */
public class CompactTrieBuilder<V> implements Iterable<Map.Entry<byte[], V>> {
    // First node of the top-level sibling chain; null until the first insert.
    TrieNode<V> root;

    /**
     * Returns the first node of the top-level sibling chain.
     *
     * @return the root node, or {@code null} if nothing has been inserted yet
     */
    public ITrieNode<V> getRoot() {
        return root;
    }

    /**
     * Returns an iterator over all entries with a non-null value, in lexicographic key order
     * (bytes compared as signed values). Every returned key is a fresh copy.
     */
    @Override
    public Iterator<Map.Entry<byte[], V>> iterator() {
        if (root == null) {
            return Collections.emptyIterator();
        }
        return new IteratorImpl<>(root);
    }

    /**
     * Associates {@code data} with {@code key}, replacing any value previously stored for it.
     *
     * <p>Storing {@code null} keeps the nodes in place but makes the key invisible to
     * {@link #get(byte[])} and to iteration.
     *
     * @param key  non-empty key
     * @param data value to store, may be {@code null}
     * @throws IllegalArgumentException if {@code key} is empty
     */
    public void insert(byte[] key, V data) {
        if (key.length == 0) {
            throw new IllegalArgumentException("Empty key is not supported");
        }
        if (root == null) {
            root = new TrieNode<>(key[0]);
        }
        TrieNode<V> node = root;
        // Node whose hiChild or midChild points at `node` (null while `node` is the root).
        // It is needed to splice a node with a smaller label in front of `node`.
        TrieNode<V> prev = null;
        for (int i = 0; i < key.length; i++) {
            boolean last = i == key.length - 1;
            byte keyByte = key[i];
            // Walk the ascending sibling chain; append a new node if the chain runs out.
            while (node.label < keyByte) {
                TrieNode<V> nextNode = node.hiChild;
                if (nextNode == null) {
                    nextNode = new TrieNode<>(keyByte);
                    node.hiChild = nextNode;
                }
                prev = node;
                node = nextNode;
            }
            if (node.label > keyByte) {
                // No sibling carries keyByte: insert one before `node` to keep the chain sorted.
                if (prev == null) {
                    root = new TrieNode<>(keyByte);
                    root.hiChild = node;
                    node = root;
                } else if (prev.midChild == node) {
                    prev.midChild = new TrieNode<>(keyByte);
                    prev.midChild.hiChild = node;
                    node = prev.midChild;
                } else if (prev.hiChild == node) {
                    prev.hiChild = new TrieNode<>(keyByte);
                    prev.hiChild.hiChild = node;
                    node = prev.hiChild;
                } else {
                    throw new IllegalStateException("Should never happen");
                }
            }
            if (last) {
                node.data = data;
                return;
            } else {
                TrieNode<V> mid = node.midChild;
                if (mid == null) {
                    mid = new TrieNode<>(key[i + 1]);
                    node.midChild = mid;
                }
                prev = node;
                node = mid;
            }
        }
    }

    /**
     * Looks up the value stored for {@code key}.
     *
     * @param key key to look up
     * @return the stored value, or {@code null} if the key is empty, absent, or mapped to
     *     {@code null}
     */
    public V get(byte[] key) {
        if (key.length == 0) {
            // Empty keys cannot be inserted, so they are never present.
            return null;
        }
        TrieNode<V> node = root;
        for (int i = 0; i < key.length; i++) {
            byte keyByte = key[i];
            // Siblings are sorted ascending, so stop at the first label >= keyByte.
            while (node != null && node.label < keyByte) {
                node = node.hiChild;
            }
            if (node == null || node.label != keyByte) {
                // No sibling carries this byte: the key is absent.
                return null;
            }
            if (i < key.length - 1) {
                node = node.midChild;
            }
        }
        // The loop only exits normally after matching the last byte, so node is non-null here.
        return node.data;
    }

    /** Trie node: one key byte, its value (if any), the next sibling and the next-level chain. */
    static class TrieNode<V> implements ITrieNode<V> {
        // Next sibling on the same level; its label is strictly greater than this one.
        TrieNode<V> hiChild;
        // First node of the chain for the key byte following this one.
        TrieNode<V> midChild;
        byte label;
        // Value for the key ending at this node; null when no such key is stored.
        V data;

        public TrieNode(byte label) {
            this.label = label;
        }

        @Override
        public TrieNode<V> getHiChild() {
            return hiChild;
        }

        @Override
        public TrieNode<V> getMidChild() {
            return midChild;
        }

        @Override
        public byte getLabel() {
            return label;
        }

        @Override
        public V getData() {
            return data;
        }
    }

    /** Traversal phase of a node on the iterator's explicit stack. */
    private enum Child {
        // Record the node's label in the key buffer and possibly emit its value.
        SELF,
        // Descend into the node's mid subtree.
        MID,
        // Leave the node and continue with its hi sibling.
        HI,
    }

    /** Stack frame of the iterative depth-first traversal: a node and its pending phase. */
    private static class TraverseStateEntry<V> {
        private final TrieNode<V> node;
        private final Child child;

        public TraverseStateEntry(TrieNode<V> node, Child child) {
            this.node = node;
            this.child = child;
        }
    }

    /**
     * Depth-first, stack-based iterator. For each node it visits the node itself, then its mid
     * subtree, then its hi sibling, which yields keys in lexicographic order.
     */
    private static class IteratorImpl<V> implements Iterator<Map.Entry<byte[], V>> {
        // Key prefix of the current traversal path.
        private byte[] buff = new byte[16];
        private final Deque<TraverseStateEntry<V>> stack = new ArrayDeque<>();
        // Entry found by hasNext() but not yet handed out by next(); null if none is buffered.
        private Map.Entry<byte[], V> next = null;
        // One-based position in buff at which the next SELF phase writes its label.
        private int depth = 1;

        public IteratorImpl(TrieNode<V> root) {
            stack.add(new TraverseStateEntry<>(root, Child.SELF));
        }

        /** Grows {@code buff} so it can hold at least {@code size} bytes. */
        private void assureBufSize(int size) {
            if (buff.length < size) {
                buff = Arrays.copyOf(buff, Math.max(size, buff.length * 2));
            }
        }

        /**
         * Advances the traversal until an entry with a non-null value is buffered or the
         * traversal is exhausted.
         */
        @Override
        public boolean hasNext() {
            while (next == null && !stack.isEmpty()) {
                TraverseStateEntry<V> cur = stack.pollLast();
                TrieNode<V> node = cur.node;
                switch (cur.child) {
                    case SELF:
                        assureBufSize(depth);
                        buff[depth - 1] = node.label;
                        if (node.data != null) {
                            next = Pair.of(Arrays.copyOf(buff, depth), node.data);
                        }
                        stack.addLast(new TraverseStateEntry<>(node, Child.MID));
                        depth++;
                        break;
                    case MID:
                        // Push HI first so the hi sibling runs after the whole mid subtree.
                        stack.addLast(new TraverseStateEntry<>(node, Child.HI));
                        if (node.midChild != null) {
                            stack.addLast(new TraverseStateEntry<>(node.midChild, Child.SELF));
                        }
                        break;
                    case HI:
                        // The hi sibling occupies the same key position as this node.
                        depth--;
                        if (node.hiChild != null) {
                            stack.addLast(new TraverseStateEntry<>(node.hiChild, Child.SELF));
                        }
                        break;
                    default:
                        throw new IllegalStateException(cur.child.name());
                }
            }
            return next != null;
        }

        @Override
        public Map.Entry<byte[], V> next() {
            if (hasNext()) {
                Map.Entry<byte[], V> res = next;
                next = null;
                return res;
            } else {
                throw new NoSuchElementException();
            }
        }
    }
}
