package ru.yandex.metabase.client;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.IntStream;

import ru.yandex.solomon.labels.LabelKeys;
import ru.yandex.solomon.model.protobuf.Label;
import ru.yandex.solomon.util.Escaper;

/**
 * Utilities for assigning label sets to Metabase partitions.
 *
 * <p>A label list is normalized (project/cluster/service labels dropped, the rest sorted by key),
 * serialized to a canonical {@code &key=value&...&} string, and the unsigned 32-bit
 * {@link String#hashCode()} of that string is mapped onto one of {@code totalPartitions}
 * contiguous ranges of the hash space.
 *
 * @author Egor Litvinenko
 */
public class MetabasePartitions {

    public static final int MAX_PARTITIONS_POWER = 10;
    public static final int MAX_PARTITIONS = 1 << MAX_PARTITIONS_POWER;

    /** Hashes produced by {@link #hash(List)} are unsigned 32-bit values. */
    private static final int MAX_POWER_OF_HASH = 32;
    /** Exclusive upper bound of the hash space: {@code 2^32}. */
    private static final long MAX_HASH = 1L << MAX_POWER_OF_HASH;

    private static final char ESCAPE_CHAR = '\\';
    /** Configured with the {@code '='} / {@code '&'} separators used by the label serializer. */
    private static final Escaper ESCAPER = new Escaper("=&", ESCAPE_CHAR);

    /** Number of labels (project, cluster, service) that {@link #normalize(List)} strips. */
    private static final int PCS_LABELS_COUNT = 3;

    private static final Comparator<Label> LABEL_SORT_BY_KEY = Comparator.comparing(Label::getKey);

    /** Lookup table {@code 2^1 .. 2^32} used by {@link #calcPowerOfTwo(long)}. */
    private static final long[] HASH_POWERS_OF_TWO =
            IntStream.rangeClosed(1, MAX_POWER_OF_HASH).mapToLong(x -> 1L << x).toArray();

    /**
     * @return the ordering normalized label lists follow: by key, in natural string order
     */
    public static Comparator<Label> labelHashComparator() {
        return LABEL_SORT_BY_KEY;
    }

    /**
     * Normalizes labels before hashing: removes the project, cluster and service labels and sorts
     * the remaining labels by key.
     *
     * @param labels labels to normalize
     * @return {@code labels} itself if it contains no project/cluster/service label and is already
     *     sorted by key; otherwise a new mutable list without those labels, sorted by key
     */
    public static List<Label> normalize(List<Label> labels) {
        if (isNormalized(labels)) {
            return labels;
        }
        List<Label> normalized = new ArrayList<>(Math.max(1, labels.size() - PCS_LABELS_COUNT));
        for (Label label : labels) {
            if (!isPcsLabel(label.getKey())) {
                normalized.add(label);
            }
        }
        normalized.sort(labelHashComparator());
        return normalized;
    }

    /** Returns true if no label is a project/cluster/service label and keys are non-decreasing. */
    private static boolean isNormalized(List<Label> labels) {
        Label prev = null;
        for (Label label : labels) {
            if (isPcsLabel(label.getKey())) {
                return false;
            }
            if (prev != null && prev.getKey().compareTo(label.getKey()) > 0) {
                return false;
            }
            prev = label;
        }
        return true;
    }

    /** Returns true for the project, cluster and service label keys. */
    private static boolean isPcsLabel(String key) {
        switch (key) {
            case LabelKeys.PROJECT:
            case LabelKeys.CLUSTER:
            case LabelKeys.SERVICE:
                return true;
            default:
                return false;
        }
    }

    /**
     * Computes the partitioning hash of a label list.
     *
     * @param labels normalized labels
     * @return the unsigned 32-bit hash of the serialized labels, in {@code [0, 2^32)}
     * @throws IllegalArgumentException if {@code labels} is null or empty
     */
    public static long hash(List<Label> labels) {
        if (labels == null || labels.isEmpty()) {
            throw new IllegalArgumentException("labels is empty");
        }
        return Integer.toUnsignedLong(format(labels).hashCode());
    }

    /**
     * @param totalPartitions total number of partitions for shard
     * @param labels normalized labels
     * @return partition id for given labels; always 0 when {@code totalPartitions <= 1}
     */
    public static int labelsPartition(int totalPartitions, List<Label> labels) {
        if (totalPartitions <= 1) {
            return 0;
        }
        return toPartition(hash(labels), totalPartitions);
    }

    /**
     * Maps a hash onto one of {@code totalPartitions} contiguous ranges of {@code [0, 2^32)}.
     *
     * @param hash hash in {@code [0, 2^32)}, as returned by {@link #hash(List)}
     * @param totalPartitions total number of partitions for given shard
     * @return partition id in {@code [0, totalPartitions)}; always 0 when {@code totalPartitions <= 1}
     * @throws IllegalArgumentException if {@code hash} is outside {@code [0, 2^32)}
     */
    public static int toPartition(long hash, int totalPartitions) {
        if (totalPartitions <= 1) {
            return 0;
        }
        if (hash < 0) {
            throw new IllegalArgumentException("hash must be positive or zero");
        }
        if (hash >= MAX_HASH) {
            throw new IllegalArgumentException("hash must be less than 2^" + MAX_POWER_OF_HASH + ": " + hash);
        }
        int power = calcPowerOfTwo(totalPartitions);
        if (power > -1) {
            // Power-of-two count: the top `power` bits of the 32-bit hash are the partition id.
            return Math.toIntExact(hash >> (MAX_POWER_OF_HASH - power));
        }
        long partitionCapacity = MAX_HASH / totalPartitions;
        // partitionCapacity is floored, so hashes in [totalPartitions * partitionCapacity, 2^32)
        // would yield totalPartitions (out of range); fold that remainder into the last partition.
        long partition = hash / partitionCapacity;
        return Math.toIntExact(Math.min(partition, totalPartitions - 1L));
    }

    private static String format(List<Label> labels) {
        return format(labels::forEach);
    }

    /**
     * Serializes labels as {@code &key=value&key=value&}, escaping keys and values with
     * {@link #ESCAPER}.
     *
     * It is copy-pasted from ru.yandex.solomon.coremon.meta.db.ydb.LabelListSortedSerialize
     */
    private static String format(Consumer<Consumer<Label>> forEach) {
        StringBuilder result = new StringBuilder(256);
        result.append('&');
        forEach.accept(label -> {
            ESCAPER.escapeTo(label.getKey(), result);
            result.append('=');
            ESCAPER.escapeTo(label.getValue(), result);
            result.append('&');
        });
        return result.toString();
    }

    /**
     * @param x candidate value
     * @return {@code k} if {@code x == 2^k} for some {@code 1 <= k <= 32}, otherwise {@code -1}
     */
    static int calcPowerOfTwo(long x) {
        for (int i = 0; i < HASH_POWERS_OF_TWO.length; i++) {
            if (HASH_POWERS_OF_TWO[i] == x) {
                return i + 1;
            }
        }
        return -1;
    }

}
