package ru.yandex.matrixnet;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Predicate;

/**
 * Matrixnet model whose base score is adjusted by a sequence of per-factor boosters.
 *
 * <p>{@link #score(double[])} first computes the score of the underlying
 * {@link BasicMatrixnetModel}. Then, if no predicate was supplied or the predicate accepts the
 * factor vector, it visits each configured factor index in order. For each one it replaces the
 * running score with {@code boosters[i].apply(score, factors[indices[i]])}.
 */
public class BoostableMatrixnetModel extends BasicMatrixnetModel {
    /** Multiplier used by the default {@link LinearPositiveBooster} for every boosted factor. */
    private static final int DEFAULT_BOOST = 10;

    // Factor positions to boost; boosters[i] is applied with factors[indices[i]].
    private final int[] indices;
    // Gate for boosting; null means "always boost".
    private final Predicate<? super double[]> predicate;
    // One booster per entry of indices, applied in the same order.
    private final FactorBooster[] boosters;

    /**
     * Creates a model that boosts the given factors with explicitly supplied boosters.
     *
     * @param config configuration passed to the base model
     * @param indices factor positions to boost; defensively copied
     * @param predicate condition on the factor vector that enables boosting; {@code null}
     *     means boosting is always applied
     * @param boosters one booster per entry of {@code indices}; defensively copied
     * @throws MatrixnetModelParseException propagated from the base model constructor
     * @throws NullPointerException if {@code indices}, {@code boosters} or any booster is null
     * @throws IllegalArgumentException if the arrays differ in length or an index is negative
     */
    public BoostableMatrixnetModel(
        final ImmutableMatrixnetModelConfig config,
        final int[] indices,
        final Predicate<? super double[]> predicate,
        final FactorBooster[] boosters)
        throws MatrixnetModelParseException
    {
        super(config);

        Objects.requireNonNull(indices, "indices");
        Objects.requireNonNull(boosters, "boosters");
        if (indices.length != boosters.length) {
            throw new IllegalArgumentException(
                "indices and boosters length mismatch: "
                    + indices.length + " != " + boosters.length);
        }
        for (int i = 0; i < indices.length; i++) {
            // Reject bad positions now instead of failing later inside score().
            if (indices[i] < 0) {
                throw new IllegalArgumentException(
                    "negative factor index " + indices[i] + " at position " + i);
            }
            Objects.requireNonNull(boosters[i], "boosters[" + i + "]");
        }

        this.predicate = predicate;
        this.indices = Arrays.copyOf(indices, indices.length);
        this.boosters = Arrays.copyOf(boosters, boosters.length);
    }

    /**
     * Creates a model that boosts the given factors with the default linear booster
     * (multiplier {@value #DEFAULT_BOOST}) whenever {@code predicate} accepts the factors.
     *
     * @param config configuration passed to the base model
     * @param indices factor positions to boost; defensively copied
     * @param predicate condition enabling boosting; {@code null} means always boost
     * @throws MatrixnetModelParseException propagated from the base model constructor
     * @throws NullPointerException if {@code indices} is null
     * @throws IllegalArgumentException if any index is negative
     */
    public BoostableMatrixnetModel(
        final ImmutableMatrixnetModelConfig config,
        final int[] indices,
        final Predicate<? super double[]> predicate)
        throws MatrixnetModelParseException
    {
        this(config, indices, predicate, defaultBoosters(indices));
    }

    /**
     * Creates a model that always boosts the given factors with the default linear booster.
     *
     * @param config configuration passed to the base model
     * @param indices factor positions to boost; defensively copied
     * @throws MatrixnetModelParseException propagated from the base model constructor
     * @throws NullPointerException if {@code indices} is null
     * @throws IllegalArgumentException if any index is negative
     */
    public BoostableMatrixnetModel(
        final ImmutableMatrixnetModelConfig config,
        final int[] indices)
        throws MatrixnetModelParseException
    {
        this(config, indices, null);
    }

    /** Builds one default booster per index; the booster is immutable, so sharing is safe. */
    private static FactorBooster[] defaultBoosters(final int[] indices) {
        final FactorBooster[] result =
            new FactorBooster[Objects.requireNonNull(indices, "indices").length];
        Arrays.fill(result, new LinearPositiveBooster(DEFAULT_BOOST));
        return result;
    }

    /**
     * Scores {@code factors} with the base model and then applies the configured boosters.
     *
     * @param factors factor vector; must be long enough to contain every configured index
     *     when boosting applies, otherwise an {@link ArrayIndexOutOfBoundsException} is thrown
     * @return the base score, possibly adjusted by each booster in order
     */
    @Override
    public double score(final double[] factors) {
        double score = super.score(factors);
        if (predicate == null || predicate.test(factors)) {
            for (int i = 0; i < indices.length; i++) {
                score = boosters[i].apply(score, factors[indices[i]]);
            }
        }

        return score;
    }

    /** Combines the current score with a single factor value to produce a new score. */
    @FunctionalInterface
    public interface FactorBooster {
        /**
         * Returns the adjusted score.
         *
         * @param score running score before this booster
         * @param factor value of the boosted factor
         * @return the new running score
         */
        double apply(double score, double factor);
    }

    /**
     * Booster that adds {@code boost * factor} to the score.
     *
     * <p>NOTE(review): despite the name, no sign clamping is performed; a negative factor or
     * boost lowers the score. Confirm whether "Positive" implies a precondition on inputs.
     */
    public static class LinearPositiveBooster implements FactorBooster {
        private final int boost;

        /**
         * @param boost multiplier applied to the factor value
         */
        public LinearPositiveBooster(final int boost) {
            this.boost = boost;
        }

        @Override
        public double apply(final double score, final double factor) {
            return score + boost * factor;
        }
    }
}
