package ru.yandex.collection;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;

import ru.yandex.function.GenericBiConsumer;
import ru.yandex.function.GenericBiFunction;
import ru.yandex.function.GenericFunction;

/**
 * Associates values with {@code Pattern}s.
 *
 * <p>A non-asterisk pattern is stored under its path, either in the
 * exact-path table or in the prefix table (selected by
 * {@code pattern.prefix()}), and is distinguished from other patterns with
 * the same path by its predicate, compared via {@code equals}. The asterisk
 * pattern holds a single fallback value that is returned when nothing more
 * specific matches.
 *
 * <p>No synchronization is performed; concurrent mutation must be guarded
 * externally.
 *
 * @param <T> sample type that patterns are matched against
 * @param <V> value type
 */
public class PatternMap<T extends PatternSample, V> {
    public static final String ASTERISK = "*";

    // Exact-path patterns, iterated in insertion order of their paths.
    private final Map<String, List<Entry<T, V>>> exact = new LinkedHashMap<>();
    // Prefix patterns. NOTE(review): both iteration order and key equality
    // come from StringLengthComparator; presumably it puts longer (more
    // specific) prefixes first - confirm, and confirm it never treats two
    // distinct prefixes of equal length as the same key.
    private final Map<String, List<Entry<T, V>>> prefix =
        new TreeMap<>(StringLengthComparator.INSTANCE);
    // Fallback value returned when nothing else matches; null when unset.
    private V asterisk;

    /**
     * Creates an empty map with no asterisk value.
     */
    public PatternMap() {
        this(null);
    }

    /**
     * Creates an empty map with the given asterisk (fallback) value.
     *
     * @param asterisk fallback value, may be {@code null}
     */
    public PatternMap(final V asterisk) {
        this.asterisk = asterisk;
    }

    /**
     * Returns {@code true} if no exact or prefix patterns are stored.
     * The asterisk value may or may not be set.
     *
     * @return whether only the asterisk slot can hold a value
     */
    public boolean asteriskOnly() {
        return exact.isEmpty() && prefix.isEmpty();
    }

    // Picks the table (prefix or exact) that stores the given pattern.
    private Map<String, List<Entry<T, V>>> selectMap(
        final Pattern<T> pattern)
    {
        if (pattern.prefix()) {
            return prefix;
        } else {
            return exact;
        }
    }

    // Returns the index of the entry whose predicate equals the given one,
    // or -1 if there is none. Keeps the receiver order
    // predicate.equals(entry.predicate) in case equals is not symmetric.
    private static <T, V> int indexOf(
        final List<Entry<T, V>> entries,
        final Predicate<T> predicate)
    {
        int size = entries.size();
        for (int i = 0; i < size; ++i) {
            if (predicate.equals(entries.get(i).predicate)) {
                return i;
            }
        }
        return -1;
    }

    // Returns the first entry whose predicate accepts the sample, or null.
    // Returns the entry rather than its value so that null-valued entries
    // (possible after transform) are still reported as matches.
    private static <T, V> Entry<T, V> firstMatch(
        final List<Entry<T, V>> entries,
        final T sample)
    {
        int size = entries.size();
        for (int i = 0; i < size; ++i) {
            Entry<T, V> entry = entries.get(i);
            if (entry.predicate.test(sample)) {
                return entry;
            }
        }
        return null;
    }

    /**
     * Associates {@code obj} with {@code pattern}.
     * Passing a {@code null} value removes the pattern instead.
     *
     * @param pattern pattern to store the value under
     * @param obj value to store, or {@code null} to remove the pattern
     * @return the value previously associated with the pattern, or
     *     {@code null} if there was none
     */
    public V put(final Pattern<T> pattern, final V obj) {
        if (obj == null) {
            return remove(pattern);
        }
        if (pattern.isAsterisk()) {
            V result = asterisk;
            asterisk = obj;
            return result;
        }
        Map<String, List<Entry<T, V>>> map = selectMap(pattern);
        String path = pattern.path();
        List<Entry<T, V>> pathEntries =
            map.computeIfAbsent(path, x -> new ArrayList<>());
        Predicate<T> predicate = pattern.predicate();
        int index = indexOf(pathEntries, predicate);
        if (index >= 0) {
            Entry<T, V> entry = pathEntries.get(index);
            V oldValue = entry.value;
            entry.value = obj;
            return oldValue;
        }
        pathEntries.add(new Entry<>(predicate, obj));
        return null;
    }

    /**
     * Removes the value associated with {@code pattern}.
     * A path whose last entry is removed is dropped from its table.
     *
     * @param pattern pattern to remove
     * @return the removed value, or {@code null} if the pattern was absent
     */
    public V remove(final Pattern<T> pattern) {
        if (pattern.isAsterisk()) {
            V result = asterisk;
            asterisk = null;
            return result;
        }
        Map<String, List<Entry<T, V>>> map = selectMap(pattern);
        String path = pattern.path();
        List<Entry<T, V>> pathEntries = map.get(path);
        if (pathEntries == null) {
            return null;
        }
        int index = indexOf(pathEntries, pattern.predicate());
        if (index < 0) {
            return null;
        }
        V result = pathEntries.get(index).value;
        if (pathEntries.size() == 1) {
            // Last entry for this path: drop the path entirely so that
            // asteriskOnly() reflects the real content.
            map.remove(path);
        } else {
            pathEntries.remove(index);
        }
        return result;
    }

    /**
     * Returns the value stored for exactly this pattern (same path, same
     * prefix flag, equal predicate). Falls back to the asterisk value when
     * the pattern is not stored or is the asterisk pattern itself.
     *
     * @param pattern pattern to look up
     * @return the stored value or the asterisk value
     */
    public V get(final Pattern<T> pattern) {
        if (!pattern.isAsterisk()) {
            List<Entry<T, V>> pathEntries =
                selectMap(pattern).get(pattern.path());
            if (pathEntries != null) {
                int index = indexOf(pathEntries, pattern.predicate());
                if (index >= 0) {
                    return pathEntries.get(index).value;
                }
            }
        }
        return asterisk;
    }

    /**
     * Finds the value for a sample.
     *
     * <p>Exact-path entries are tried first, for each of
     * {@code sample.paths()} in order; then prefix entries, for each path in
     * order and each stored prefix in the prefix table's iteration order.
     * Within a path the first entry whose predicate accepts the sample wins.
     * If nothing matches, the asterisk value is returned.
     *
     * @param sample sample to match
     * @return the matching value or the asterisk value
     */
    public V get(final T sample) {
        List<String> paths = sample.paths();
        int pathsSize = paths.size();
        for (int i = 0; i < pathsSize; ++i) {
            List<Entry<T, V>> pathEntries = exact.get(paths.get(i));
            if (pathEntries != null) {
                Entry<T, V> match = firstMatch(pathEntries, sample);
                if (match != null) {
                    return match.value;
                }
            }
        }

        if (!prefix.isEmpty()) {
            for (int i = 0; i < pathsSize; ++i) {
                String path = paths.get(i);
                for (Map.Entry<String, List<Entry<T, V>>> entry
                    : prefix.entrySet())
                {
                    if (path.startsWith(entry.getKey())) {
                        Entry<T, V> match =
                            firstMatch(entry.getValue(), sample);
                        if (match != null) {
                            return match.value;
                        }
                    }
                }
            }
        }
        return asterisk;
    }

    /**
     * Returns the asterisk (fallback) value.
     *
     * @return the asterisk value, or {@code null} if unset
     */
    public V asterisk() {
        return asterisk;
    }

    // Copies every path list of src into dst, transforming each value and
    // reusing the original predicates.
    private static <T, V, U, E extends Exception> void transform(
        final Map<String, List<Entry<T, V>>> src,
        final Map<String, List<Entry<T, U>>> dst,
        final GenericFunction<? super V, ? extends U, E> transformer)
        throws E
    {
        for (Map.Entry<String, List<Entry<T, V>>> entry: src.entrySet()) {
            List<Entry<T, V>> srcEntries = entry.getValue();
            int size = srcEntries.size();
            List<Entry<T, U>> dstEntries = new ArrayList<>(size);
            dst.put(entry.getKey(), dstEntries);
            for (int i = 0; i < size; ++i) {
                Entry<T, V> pathEntry = srcEntries.get(i);
                dstEntries.add(
                    new Entry<>(
                        pathEntry.predicate,
                        transformer.apply(pathEntry.value)));
            }
        }
    }

    /**
     * Returns a new map with the same patterns whose values are
     * {@code transformer} applied to this map's values. The asterisk value
     * is transformed only when it is non-null.
     *
     * @param transformer value mapping function
     * @param <U> new value type
     * @param <E> exception type the transformer may throw
     * @return the transformed map
     * @throws E if the transformer throws
     */
    public <U, E extends Exception> PatternMap<T, U> transform(
        final GenericFunction<? super V, ? extends U, E> transformer)
        throws E
    {
        PatternMap<T, U> mapper = new PatternMap<>();
        transform(exact, mapper.exact, transformer);
        transform(prefix, mapper.prefix, transformer);
        if (asterisk != null) {
            mapper.asterisk = transformer.apply(asterisk);
        }
        return mapper;
    }

    // Same as the single-argument variant, but also hands the transformer a
    // Pattern rebuilt from the path, the prefix flag and the entry predicate.
    private static <T, V, U, E extends Exception> void transform(
        final Map<String, List<Entry<T, V>>> src,
        final Map<String, List<Entry<T, U>>> dst,
        final GenericBiFunction<? super Pattern<T>, ? super V, ? extends U, E> transformer,
        final boolean prefix)
        throws E
    {
        for (Map.Entry<String, List<Entry<T, V>>> entry: src.entrySet()) {
            List<Entry<T, V>> srcEntries = entry.getValue();
            int size = srcEntries.size();
            List<Entry<T, U>> dstEntries = new ArrayList<>(size);
            String path = entry.getKey();
            dst.put(path, dstEntries);
            for (int i = 0; i < size; ++i) {
                Entry<T, V> pathEntry = srcEntries.get(i);
                dstEntries.add(
                    new Entry<>(
                        pathEntry.predicate,
                        transformer.apply(
                            new Pattern<>(path, prefix, pathEntry.predicate),
                            pathEntry.value)));
            }
        }
    }

    /**
     * Returns a new map with the same patterns whose values are
     * {@code transformer} applied to each (pattern, value) pair. The
     * asterisk value, if non-null, is passed together with
     * {@code new Pattern<>("", true)}.
     *
     * @param transformer pattern-aware value mapping function
     * @param <U> new value type
     * @param <E> exception type the transformer may throw
     * @return the transformed map
     * @throws E if the transformer throws
     */
    public <U, E extends Exception> PatternMap<T, U> transform(
        final GenericBiFunction<? super Pattern<T>, ? super V, ? extends U, E> transformer)
        throws E
    {
        PatternMap<T, U> mapper = new PatternMap<>();
        transform(exact, mapper.exact, transformer, false);
        transform(prefix, mapper.prefix, transformer, true);
        if (asterisk != null) {
            mapper.asterisk =
                transformer.apply(new Pattern<>("", true), asterisk);
        }
        return mapper;
    }

    // Feeds every (pattern, value) pair of one table to the visitor.
    private static <T, V, E extends Exception> void traverse(
        final GenericBiConsumer<Pattern<T>, V, E> visitor,
        final Map<String, List<Entry<T, V>>> map,
        final boolean prefix)
        throws E
    {
        for (Map.Entry<String, List<Entry<T, V>>> entry: map.entrySet()) {
            String path = entry.getKey();
            List<Entry<T, V>> pathEntries = entry.getValue();
            int size = pathEntries.size();
            for (int i = 0; i < size; ++i) {
                Entry<T, V> pathEntry = pathEntries.get(i);
                visitor.accept(
                    new Pattern<>(path, prefix, pathEntry.predicate),
                    pathEntry.value);
            }
        }
    }

    /**
     * Visits every stored (pattern, value) pair: exact patterns first, then
     * prefix patterns, then the asterisk value (if non-null) paired with
     * {@code new Pattern<>("", true)}.
     *
     * @param visitor callback receiving each pair
     * @param <E> exception type the visitor may throw
     * @throws E if the visitor throws
     */
    public <E extends Exception> void traverse(
        final GenericBiConsumer<Pattern<T>, V, E> visitor)
        throws E
    {
        traverse(visitor, exact, false);
        traverse(visitor, prefix, true);
        if (asterisk != null) {
            visitor.accept(new Pattern<>("", true), asterisk);
        }
    }

    // A predicate/value pair stored under a path. The value is mutable so
    // that put() can overwrite it in place.
    private static class Entry<T, V> {
        private final Predicate<T> predicate;
        private V value;

        Entry(final Predicate<T> predicate, final V value) {
            this.predicate = predicate;
            this.value = value;
        }

        // Without this, PatternMap.toString() would print identity hashes
        // such as "PatternMap$Entry@1b6d3586" for every stored entry.
        @Override
        public String toString() {
            return "Entry{predicate=" + predicate + ", value=" + value + '}';
        }
    }

    @Override
    public String toString() {
        return "PatternMap{" +
            "exact=" + exact +
            ", prefix=" + prefix +
            ", asterisk=" + asterisk +
            '}';
    }
}

