package ru.yandex.concurrent;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.Predicate;

import sun.misc.Unsafe;

/**
 * A hash map split into a fixed, power-of-two number of independent
 * {@link HashMap} segments ("buckets"), each guarded by its own monitor.
 *
 * <p>Single-key operations lock only the segment the key maps to, so
 * operations on keys in different segments do not contend with each other.
 * Keys must be non-null: a {@code null} key causes a
 * {@link NullPointerException} (from {@code key.hashCode()}). Several methods
 * use {@code null} to mean "no value", so storing {@code null} values makes
 * those results ambiguous.
 *
 * <p>Whole-map operations ({@link #size()}, {@link #removeAnyValue()},
 * {@link #removeAnyValueMatching(Predicate)}) visit segments one at a time
 * and are not atomic across segments.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class ParallelHashMap<K, V> {
    // Largest segment count the constructor rounds up to (1 << 31 overflows int).
    private static final int MAX_BUCKETS = 1 << 30;

    private static final Unsafe UNSAFE;
    static {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            UNSAFE = (Unsafe) theUnsafe.get(null);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

    private final int buckets;
    private final int mask;
    private final List<Map<K, V>> maps;
    // On each call to removeAnyValue/removeAnyValueMatching we don't want to
    // always start from the same bucket, so we use an iteration seed that is
    // incremented on every such call. We don't really care about fairness
    // here, so its modification is deliberately not synchronized: a lost or
    // stale increment only changes where the scan starts.
    private int iterationSeed = 0;

    /**
     * Creates a map with {@code buckets} rounded up to the next power of two
     * segments (at least 1, at most 2^30).
     *
     * @param buckets requested number of segments; values below 1 yield 1
     */
    public ParallelHashMap(final int buckets) {
        this.buckets = roundUpToPowerOfTwo(buckets);
        mask = this.buckets - 1;
        maps = new ArrayList<>(this.buckets);
        for (int i = 0; i < this.buckets; ++i) {
            maps.add(new HashMap<>());
        }
    }

    /** Returns the smallest power of two {@code >= n}, clamped to [1, MAX_BUCKETS]. */
    private static int roundUpToPowerOfTwo(final int n) {
        if (n <= 1) {
            return 1;
        }
        if (n >= MAX_BUCKETS) {
            return MAX_BUCKETS;
        }
        // n - 1 >= 1 here, so this never overflows.
        return Integer.highestOneBit(n - 1) << 1;
    }

    /**
     * MurmurHash3 32-bit finalizer. The per-segment HashMap indexes by the
     * low bits of the raw hash code; if the segment were also chosen by those
     * low bits, every key in a segment would share them and collide inside
     * that segment's table. Mixing all bits first decorrelates the two.
     */
    private static int spread(int h) {
        h ^= h >>> 16;
        h *= 0x85ebca6b;
        h ^= h >>> 13;
        h *= 0xc2b2ae35;
        h ^= h >>> 16;
        return h;
    }

    /** Returns the segment responsible for {@code key}; NPE if key is null. */
    private Map<K, V> mapFor(final K key) {
        return maps.get(spread(key.hashCode()) & mask);
    }

    /**
     * Returns the total number of entries across all segments.
     *
     * <p>Segment sizes are read without taking their locks, so under
     * concurrent modification the result is only an approximation.
     *
     * @return the (approximate) number of entries
     */
    public int size() {
        // Best-effort ordering hint before the unlocked reads below; this
        // does not make the sum an atomic snapshot.
        UNSAFE.loadFence();
        int size = 0;
        for (int i = 0; i < buckets; ++i) {
            size += maps.get(i).size();
        }
        return size;
    }

    /**
     * @param key non-null key
     * @return the value mapped to {@code key}, or {@code null} if none
     */
    public V get(final K key) {
        Map<K, V> map = mapFor(key);
        synchronized (map) {
            return map.get(key);
        }
    }

    /**
     * Associates {@code value} with {@code key}, replacing any previous value.
     *
     * @param key non-null key
     * @param value value to store
     * @return the previous value, or {@code null} if there was none
     */
    public V put(final K key, final V value) {
        Map<K, V> map = mapFor(key);
        synchronized (map) {
            return map.put(key, value);
        }
    }

    /**
     * Stores {@code value} only if {@code key} has no current value.
     *
     * @param key non-null key
     * @param value value to store
     * @return the existing value, or {@code null} if {@code value} was stored
     */
    public V putIfAbsent(final K key, final V value) {
        Map<K, V> map = mapFor(key);
        synchronized (map) {
            return map.putIfAbsent(key, value);
        }
    }

    /**
     * Stores {@code value} only if {@code key} has no current value, then
     * returns {@code oldValueRemapper.apply(value, previous)}, where
     * {@code previous} is the existing value or {@code null} if
     * {@code value} was stored.
     *
     * <p>The remapper runs while the key's segment lock is held, so it
     * should be short and must not access this map.
     *
     * @param key non-null key
     * @param value value to store if absent
     * @param oldValueRemapper combines (value, previous-or-null) into the result
     * @return whatever {@code oldValueRemapper} returns
     */
    public V putIfAbsent(
        final K key,
        final V value,
        final BinaryOperator<V> oldValueRemapper)
    {
        Map<K, V> map = mapFor(key);
        synchronized (map) {
            return oldValueRemapper.apply(value, map.putIfAbsent(key, value));
        }
    }

    /**
     * @param key non-null key
     * @return the removed value, or {@code null} if there was none
     */
    public V remove(final K key) {
        Map<K, V> map = mapFor(key);
        synchronized (map) {
            return map.remove(key);
        }
    }

    /**
     * Removes and returns some value from the map. Successive calls start
     * scanning from different segments.
     *
     * @return a removed value, or {@code null} if no value was found (the
     *     scan is not atomic across segments, so concurrently added entries
     *     may be missed)
     */
    public V removeAnyValue() {
        return removeAnyValueMatching(value -> true);
    }

    /**
     * Removes and returns some value satisfying {@code predicate}, scanning
     * all entries of every segment. Successive calls start scanning from
     * different segments.
     *
     * <p>{@code predicate} is evaluated while the current segment's lock is
     * held, so it should be short and must not access this map.
     *
     * @param predicate condition a value must satisfy to be removed
     * @return a removed matching value, or {@code null} if none was found
     */
    public V removeAnyValueMatching(final Predicate<V> predicate) {
        UNSAFE.loadFence();
        int seed = ++iterationSeed;
        for (int i = 0; i < buckets; ++i) {
            Map<K, V> map = maps.get((i + seed) & mask);
            // Unlocked pre-check to skip empty segments without contending
            // for their locks; the iterator re-checks under the lock.
            if (!map.isEmpty()) {
                synchronized (map) {
                    Iterator<V> iter = map.values().iterator();
                    while (iter.hasNext()) {
                        V value = iter.next();
                        if (predicate.test(value)) {
                            iter.remove();
                            return value;
                        }
                    }
                }
            }
        }
        return null;
    }
}

