package ru.yandex.solomon.util.mh;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.stream.IntStream;

import ru.yandex.commune.mh.MhArrays;
import ru.yandex.commune.mh.MhFields;
import ru.yandex.commune.mh.MhInts;
import ru.yandex.commune.mh.Mhx;

/**
 * {@link MultiArrayResizer} implementation that builds, once per holder class {@code A}, a set of
 * {@link MethodHandle}s over the fields described by {@link MultiArrayResizerCommon}: the parallel
 * array fields ({@code arrayFields}), the logical size field ({@code sizeField}) and the array field
 * whose length is reported as the capacity ({@code capacityField}).
 *
 * <p>Most operations are exposed in two shapes: an "AParam" handle whose first parameter is typed
 * as {@code A}, and an "ObjectParam" handle whose first parameter is cast to {@code Object} via
 * {@link MethodHandles#explicitCastArguments}; the latter is what the interface methods call with
 * {@code invokeExact}.
 *
 * @author Stepan Koltsov
 */
public class MultiArrayResizerMh<A> implements MultiArrayResizer<A> {

    /** Set to {@code true} to wrap every generated handle with {@code Mhx.trace} while debugging. */
    private static final boolean TRACE = false;

    // A, exactCapacity: int -> ()
    private final MethodHandle resizeMhAParam;
    // Object, exactCapacity: int -> ()
    private final MethodHandle resizeMhObjectParam;

    // A -> int
    private final MethodHandle capacityMhAParam;
    // Object -> int
    private final MethodHandle capacityMhObjectParam;

    // A -> int
    private final MethodHandle sizeMhAParam;
    // Object -> int
    private final MethodHandle sizeMhObjectParam;

    // A, additional: int -> ()
    private final MethodHandle reserveAdditionalMhAParam;
    // Object, additional: int -> ()
    private final MethodHandle reserveAdditionalMhObjectParam;

    /**
     * A, index: int, f... -> (). Kept public for backward compatibility;
     * prefer {@link #getSetAtMhAParam()}.
     */
    public final MethodHandle setAtMhAParam;
    // Object, index: int, Object... -> ()
    private final MethodHandle setAtMhObjectParam;

    /**
     * A, f... -> (). Kept public for backward compatibility;
     * prefer {@link #getAddMhAParam()}.
     */
    public final MethodHandle addMhAParam;

    /**
     * Builds all method handles for {@code clazz}.
     *
     * @param clazz holder class whose fields are described by {@link MultiArrayResizerCommon}
     * @throws RuntimeException if any handle cannot be built (checked reflective failures are wrapped)
     */
    public MultiArrayResizerMh(Class<A> clazz) {
        MultiArrayResizerCommon<A> common = new MultiArrayResizerCommon<>(clazz);

        try {

            resizeMhAParam = makeResize(common);
            resizeMhObjectParam = MethodHandles.explicitCastArguments(
                resizeMhAParam,
                MethodType.methodType(void.class, Object.class, int.class));

            capacityMhAParam = makeCapacity(common.capacityField);
            capacityMhObjectParam = MethodHandles.explicitCastArguments(
                capacityMhAParam,
                MethodType.methodType(int.class, Object.class));

            sizeMhAParam = makeGetSize(common.sizeField);
            sizeMhObjectParam = MethodHandles.explicitCastArguments(
                sizeMhAParam,
                MethodType.methodType(int.class, Object.class));

            reserveAdditionalMhAParam = makeReserveAdditional(common);
            reserveAdditionalMhObjectParam = MethodHandles.explicitCastArguments(
                reserveAdditionalMhAParam,
                MethodType.methodType(void.class, Object.class, int.class));

            setAtMhAParam = makeSetAt(common);
            setAtMhObjectParam = MethodHandles.explicitCastArguments(
                setAtMhAParam,
                MethodType.methodType(void.class, Object.class, int.class)
                    .appendParameterTypes(Arrays.stream(common.arrayFields).map(f -> Object.class).toArray(Class[]::new))
            );

            addMhAParam = makeAdd(common);
        } catch (Exception e) {
            throw unchecked(e);
        }
    }

    /**
     * Builds a handle that resizes every array field to exactly the given capacity.
     *
     * @return handle of type {@code (A, exactCapacity: int) -> ()}
     */
    private static MethodHandle makeResize(MultiArrayResizerCommon<?> common) {
        MethodHandle[] fields = Arrays.stream(common.arrayFields)
            .map(MultiArrayResizerMh::makeResizeField)
            .toArray(MethodHandle[]::new);

        // presumably invokes each per-field handle with the same (A, int) arguments
        MethodHandle r = Mhx.invokeAll(fields);
        return trace("resize", r);
    }

    /**
     * Builds a getter for the logical size field.
     *
     * @return handle of type {@code A -> int}
     */
    private static MethodHandle makeGetSize(Field sizeField) throws Exception {
        MethodHandle r = MethodHandles.lookup().unreflectGetter(sizeField);
        return trace("get size", r);
    }

    /**
     * Builds a handle that increments the size field by one.
     *
     * @return handle of type {@code A -> ()}
     */
    private static MethodHandle makeIncrementSize(Field sizeField) throws Exception {
        // int -> int: x + 1
        MethodHandle add1 = MethodHandles.insertArguments(MhInts.sum(int.class), 1, 1);
        MethodHandle r = MhFields.updateField(sizeField, add1);
        return trace("inc size", r);
    }

    /**
     * Builds a handle returning the length of the array stored in {@code capacityField}.
     *
     * @return handle of type {@code A -> int}
     */
    private static MethodHandle makeCapacity(Field capacityField) throws Exception {
        // array -> int
        MethodHandle length = MhArrays.arrayLength(capacityField.getType());

        // A -> array
        MethodHandle getter = MethodHandles.lookup().unreflectGetter(capacityField);
        MethodHandle r = MethodHandles.filterReturnValue(getter, length);
        return trace("capacity", r);
    }

    /** @return handle of type {@code (A, exactCapacity: int) -> ()} */
    public MethodHandle getResizeMhAParam() {
        return resizeMhAParam;
    }

    /** @return handle of type {@code (Object, exactCapacity: int) -> ()} */
    public MethodHandle getResizeMhObjectParam() {
        return resizeMhObjectParam;
    }

    /** @return handle of type {@code (A, additional: int) -> ()} */
    public MethodHandle getReserveAdditionalMhAParam() {
        return reserveAdditionalMhAParam;
    }

    /** @return handle of type {@code (Object, additional: int) -> ()} */
    public MethodHandle getReserveAdditionalMhObjectParam() {
        return reserveAdditionalMhObjectParam;
    }

    /** @return handle of type {@code (A, index: int, f...) -> ()}, one trailing value per array field */
    public MethodHandle getSetAtMhAParam() {
        return setAtMhAParam;
    }

    /** @return handle of type {@code (Object, index: int, Object...) -> ()}, one trailing value per array field */
    public MethodHandle getSetAtMhObjectParam() {
        return setAtMhObjectParam;
    }

    /** @return handle of type {@code (A, f...) -> ()} that appends one element to every array field */
    public MethodHandle getAddMhAParam() {
        return addMhAParam;
    }

    /**
     * Builds a handle that replaces one array field with a copy of itself of the given length.
     * Reference-typed arrays are copied through the {@code Object[]} overload and the result is
     * cast back to the field type when stored.
     *
     * @param field array-typed field of the holder class
     * @return handle of type {@code (A, newCapacity: int) -> ()}
     */
    static MethodHandle makeResizeField(Field field) {
        try {
            Class<?> arrayClass = field.getType();
            Class<?> copyOfType;
            if (arrayClass.getComponentType().isPrimitive()) {
                copyOfType = arrayClass;
            } else {
                copyOfType = Object[].class;
            }

            // a -> array
            MethodHandle getter = MethodHandles.lookup().unreflectGetter(field);
            // a, array -> ()
            MethodHandle setter = MethodHandles.lookup().unreflectSetter(field);

            if (!arrayClass.getComponentType().isPrimitive()) {
                // view the field as Object[] so it matches the Object[] copyOf overload
                setter = MethodHandles.explicitCastArguments(setter,
                    MethodType.methodType(void.class, field.getDeclaringClass(), Object[].class));
                getter = MethodHandles.explicitCastArguments(getter,
                    MethodType.methodType(Object[].class, field.getDeclaringClass()));
            }

            // array, int -> array
            MethodHandle copyOf = MhArrays.copyOf(copyOfType);

            // a, array, int -> ()
            MethodHandle copyOfThenSetter = MethodHandles.collectArguments(setter, 1, copyOf);

            // array, a, int -> ()
            MethodHandle copyOfThenSetterPermuted = MethodHandles.permuteArguments(copyOfThenSetter,
                MethodType.methodType(void.class, copyOfType, field.getDeclaringClass(), int.class),
                1, 0, 2);

            // a, int -> ()  (the getter result is fed in as the leading array argument)
            MethodHandle r = MethodHandles.foldArguments(copyOfThenSetterPermuted, getter);
            return trace("resize." + field.getName(), r);

        } catch (Exception e) {
            throw unchecked(e);
        }
    }

    /** Optionally wraps {@code methodHandle} with tracing; a no-op unless {@link #TRACE} is set. */
    private static MethodHandle trace(String name, MethodHandle methodHandle) {
        return TRACE ? Mhx.trace(name, methodHandle) : methodHandle;
    }

    /**
     * Converts a throwable caught around handle construction or invocation into something that can
     * be thrown without a {@code throws} clause. Errors (e.g. {@link OutOfMemoryError}) and unchecked
     * exceptions are propagated unchanged instead of being disguised as a generic
     * {@link RuntimeException}; only checked throwables are wrapped.
     */
    private static RuntimeException unchecked(Throwable t) {
        if (t instanceof Error) {
            throw (Error) t;
        }
        if (t instanceof RuntimeException) {
            return (RuntimeException) t;
        }
        return new RuntimeException(t);
    }

    /**
     * Builds a handle computing {@code size + additional}.
     *
     * @return handle of type {@code (A, additional: int) -> int}
     */
    private static MethodHandle makeMinRequiredCapacity(MultiArrayResizerCommon<?> common) throws Exception {
        // a -> int
        MethodHandle size = makeGetSize(common.sizeField);
        // int, int -> int
        // NOTE(review): if MhInts.sum is plain int addition, a huge `additional` overflows to a
        // negative value and no resize happens — confirm MhInts.sum semantics.
        MethodHandle sum = MhInts.sum(int.class);
        return MethodHandles.collectArguments(sum, 0, size);
    }

    /**
     * Builds the "needs to grow" predicate: {@code size + additional > capacity}.
     *
     * @return handle of type {@code (A, additional: int) -> boolean}
     */
    private static MethodHandle makeNeedResize(MultiArrayResizerCommon<?> common) throws Exception {
        // a, int -> int
        MethodHandle minRequiredCapacity = trace("minRequiredCapacity", makeMinRequiredCapacity(common));
        // a -> int
        MethodHandle capacity = makeCapacity(common.capacityField);

        MethodHandle cmp = MhInts.cmp(int.class, MhInts.Cmp.GT);
        cmp = trace("needResize.cmp", cmp);

        // a, int, a -> boolean: minRequiredCapacity > capacity
        MethodHandle interm = Mhx.filterArguments(cmp, 0, minRequiredCapacity, capacity);

        // a, int -> boolean (the holder is passed to both sub-handles)
        MethodHandle r = MethodHandles.permuteArguments(
            interm, MethodType.methodType(boolean.class, common.clazz, int.class),
            0, 1, 0);
        return trace("needResize", r);
    }

    /**
     * Builds a handle that resizes all arrays to
     * {@code MultiArrayResizerCommon.newCapacityImpl(size, capacity, additional)}.
     *
     * @return handle of type {@code (A, additional: int) -> ()}
     */
    private static MethodHandle makeResizeAdditional(MultiArrayResizerCommon<?> common) throws Exception {
        // a, int -> ()
        MethodHandle resize = makeResize(common);

        // a, a, int -> int: newCapacityImpl(size(a), capacity(a), additional)
        MethodHandle newCapacityForObject = Mhx.filterArguments(newCapacity, 0, makeGetSize(common.sizeField), makeCapacity(common.capacityField));

        // a, a, a, int -> ()
        MethodHandle interm = Mhx.filterArguments(resize, 0, MethodHandles.identity(common.clazz), newCapacityForObject);

        // a, int -> ()
        return MethodHandles.permuteArguments(interm,
            MethodType.methodType(void.class, common.clazz, int.class),
            0, 0, 0, 1);
    }

    // (int, int, int) -> int; invoked above as newCapacityImpl(size, capacity, additional)
    private static final MethodHandle newCapacity;
    static {
        try {
            newCapacity = MethodHandles.lookup().findStatic(
                MultiArrayResizerCommon.class, "newCapacityImpl", MethodType.methodType(int.class, int.class, int.class, int.class));
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }


    /**
     * Builds a handle that grows the arrays only when {@code size + additional > capacity}.
     *
     * @return handle of type {@code (A, additional: int) -> ()}
     */
    private static MethodHandle makeReserveAdditional(MultiArrayResizerCommon<?> common) throws Exception {
        MethodHandle needResize = makeNeedResize(common);
        MethodHandle resizeAdditional = makeResizeAdditional(common);
        MethodHandle nop = Mhx.nop(common.clazz, int.class);
        return MethodHandles.guardWithTest(needResize, resizeAdditional, nop);
    }


    /**
     * Builds a handle storing the {@code fieldIndex}-th trailing value into the
     * {@code fieldIndex}-th array field at {@code index}; the other trailing values are ignored.
     *
     * @return handle of type {@code (A, index: int, f...) -> ()}
     */
    private static MethodHandle makeSetFieldAt(MultiArrayResizerCommon<?> common, int fieldIndex) throws Exception {
        // A -> f[]
        MethodHandle getArray = MethodHandles.lookup().unreflectGetter(common.arrayFields[fieldIndex]);
        // f[], int, f -> ()
        MethodHandle setInArray = MethodHandles.arrayElementSetter(common.arrayFields[fieldIndex].getType());

        // A, int, f -> ()
        MethodHandle getArrayAndSetInArray = MethodHandles.filterArguments(setInArray, 0, getArray);

        MethodType resultType = MethodType.methodType(void.class, common.clazz, int.class)
            .appendParameterTypes(common.arrayElementTypes);

        // keep the holder and index, pick only this field's value from the trailing arguments
        int[] permute = new int[] { 0, 1, 2 + fieldIndex };

        return MethodHandles.permuteArguments(getArrayAndSetInArray, resultType, permute);
    }

    /**
     * Builds a handle that writes one value into every array field at the same index.
     *
     * @return handle of type {@code (A, index: int, f...) -> ()}
     */
    private static MethodHandle makeSetAt(MultiArrayResizerCommon<?> common) throws Exception {
        MethodHandle[] setters = IntStream.range(0, common.arrayFields.length)
            .mapToObj(i -> {
                try {
                    return makeSetFieldAt(common, i);
                } catch (Exception e) {
                    throw unchecked(e);
                }
            })
            .toArray(MethodHandle[]::new);
        return Mhx.invokeAll(setters);
    }


    /**
     * Builds the append handle: reserve room for one element, store the values at index
     * {@code size}, then increment {@code size} (presumably in that order, as passed to
     * {@code Mhx.invokeAllWithUnionParams}).
     *
     * @return handle of type {@code (A, f...) -> ()}
     */
    private static MethodHandle makeAdd(MultiArrayResizerCommon<?> common) throws Exception {
        // A -> ()
        MethodHandle reserveAdditional = MethodHandles.insertArguments(makeReserveAdditional(common), 1, 1);

        // A, int, f... -> ()
        MethodHandle setAt = makeSetAt(common);

        // A, A, f... -> ()
        MethodHandle setAtSize0 = MethodHandles.filterArguments(setAt, 1, makeGetSize(common.sizeField));

        // A, f... -> ()
        MethodHandle setAtSize = Mhx.duplicateArgument(setAtSize0, 0);

        // A -> ()
        MethodHandle incSize = makeIncrementSize(common.sizeField);

        return Mhx.invokeAllWithUnionParams(reserveAdditional, setAtSize, incSize);
    }


    @Override
    public void resize(A arrays, int newSize) {
        try {
            resizeMhObjectParam.invokeExact(arrays, newSize);
        } catch (Throwable throwable) {
            throw unchecked(throwable);
        }
    }

    @Override
    public int capacity(A arrays) {
        try {
            return (int) capacityMhObjectParam.invokeExact(arrays);
        } catch (Throwable throwable) {
            throw unchecked(throwable);
        }
    }

    @Override
    public int size(A arrays) {
        try {
            return (int) sizeMhObjectParam.invokeExact(arrays);
        } catch (Throwable throwable) {
            throw unchecked(throwable);
        }
    }

    @Override
    public void reserveAdditional(A arrays, int additional) {
        try {
            reserveAdditionalMhObjectParam.invokeExact(arrays, additional);
        } catch (Throwable throwable) {
            throw unchecked(throwable);
        }
    }

}
