package ru.yandex.webmaster3.core.util.trie;

import java.util.function.BiFunction;

/**
 * Lazily evaluated merged view over two {@link ITrieNode}s.
 * <p>
 * Instances are created only through {@link #createMerged}, which first aligns both sides
 * to the same label. Data and children of the merged node are computed on first access from
 * the two underlying nodes and then cached in plain, unsynchronized fields.
 *
 * @author avhaliullin
 */
public class MergedNodeImpl<V> implements ITrieNode<V> {
    private final BiFunction<V, V, V> merger;
    private final ITrieNode<V> left;
    private final ITrieNode<V> right;

    // Cached results. A separate "ready" flag accompanies each value because
    // null is a legitimate computed result and cannot serve as "not computed yet".
    private ITrieNode<V> cachedHi;
    private boolean hiReady;
    private ITrieNode<V> cachedMid;
    private boolean midReady;
    private V cachedData;
    private boolean dataReady;

    private MergedNodeImpl(BiFunction<V, V, V> merge, ITrieNode<V> left, ITrieNode<V> right) {
        this.merger = merge;
        this.left = left;
        this.right = right;
    }

    /**
     * Builds a node representing the merge of {@code left} and {@code right}.
     *
     * @param left  first node, may be {@code null}
     * @param right second node, may be {@code null}
     * @param merge function applied to the two data values when both sides have non-null data
     * @return the other argument if one of them is {@code null} (so {@code null} when both are),
     * otherwise a new merged node
     */
    public static <V> ITrieNode<V> createMerged(ITrieNode<V> left, ITrieNode<V> right, BiFunction<V, V, V> merge) {
        return mergeNeighbours(left, right, merge);
    }

    /**
     * Merges two nodes taken from the same position of the two tries.
     * <p>
     * When the labels differ, the node with the greater label is wrapped into a {@link FakeNode}
     * that carries the smaller label and exposes the wrapped node as its hi child, so that both
     * sides of the resulting merged node always have equal labels.
     * NOTE(review): labels are compared as signed bytes - confirm this matches the ordering
     * used by the underlying trie implementations.
     */
    private static <V> ITrieNode<V> mergeNeighbours(ITrieNode<V> lNeighbour, ITrieNode<V> rNeighbour, BiFunction<V, V, V> merge) {
        if (lNeighbour == null) {
            return rNeighbour;
        }
        if (rNeighbour == null) {
            return lNeighbour;
        }

        final byte lLabel = lNeighbour.getLabel();
        final byte rLabel = rNeighbour.getLabel();

        ITrieNode<V> alignedLeft = lNeighbour;
        ITrieNode<V> alignedRight = rNeighbour;
        if (lLabel > rLabel) {
            alignedLeft = new FakeNode<>(lNeighbour, rLabel);
        } else if (lLabel < rLabel) {
            alignedRight = new FakeNode<>(rNeighbour, lLabel);
        }
        return new MergedNodeImpl<>(merge, alignedLeft, alignedRight);
    }

    /**
     * Drops everything this node has cached so far; subsequent getters recompute their values.
     * Only this node's own cache is reset - {@code left} and {@code right} are not touched.
     */
    @Override
    public void freeMemory() {
        cachedData = null;
        dataReady = false;
        cachedHi = null;
        hiReady = false;
        cachedMid = null;
        midReady = false;
    }

    /**
     * @return data of the only side that has it, the result of the merge function when both
     * sides have data, or {@code null} when neither does; computed once and cached
     */
    @Override
    public V getData() {
        if (dataReady) {
            return cachedData;
        }
        V fromLeft = left.getData();
        V fromRight = right.getData();
        cachedData = combine(fromLeft, fromRight);
        dataReady = true;
        return cachedData;
    }

    /**
     * Picks the non-null value, or applies the merge function when both values are present.
     */
    private V combine(V lData, V rData) {
        if (lData == null) {
            return rData;
        }
        if (rData == null) {
            return lData;
        }
        return merger.apply(lData, rData);
    }

    /**
     * @return merge of the hi children of both sides, computed on first call and cached
     */
    @Override
    public ITrieNode<V> getHiChild() {
        if (hiReady) {
            return cachedHi;
        }
        cachedHi = mergeNeighbours(left.getHiChild(), right.getHiChild(), merger);
        hiReady = true;
        return cachedHi;
    }

    /**
     * @return merge of the mid children of both sides, computed on first call and cached
     */
    @Override
    public ITrieNode<V> getMidChild() {
        if (midReady) {
            return cachedMid;
        }
        cachedMid = mergeNeighbours(left.getMidChild(), right.getMidChild(), merger);
        midReady = true;
        return cachedMid;
    }

    /**
     * @return label of the left side; both sides are aligned to the same label before
     * a merged node is constructed
     */
    @Override
    public byte getLabel() {
        return left.getLabel();
    }

    /**
     * Placeholder node standing in for a side that has no real node with the given label:
     * it has no data and no mid child, and exposes the wrapped real node as its hi child.
     */
    private static class FakeNode<V> implements ITrieNode<V> {
        private final byte label;
        private final ITrieNode<V> wrapped;

        public FakeNode(ITrieNode<V> hiChild, byte label) {
            this.wrapped = hiChild;
            this.label = label;
        }

        @Override
        public byte getLabel() {
            return label;
        }

        @Override
        public V getData() {
            return null;
        }

        @Override
        public ITrieNode<V> getMidChild() {
            return null;
        }

        @Override
        public ITrieNode<V> getHiChild() {
            return wrapped;
        }
    }
}
