package ru.yandex.solomon.coremon.balancer.state;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import javax.annotation.concurrent.Immutable;

import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;

import ru.yandex.solomon.core.db.model.ShardPartitionId;

/**
 * Immutable set of shard partitions ({@link ShardPartitionId}) with a precomputed 64-bit hash.
 *
 * <p>Instances are never mutated after construction: every "modifying" operation returns a new
 * {@code ShardIds} (or {@code this} when there is nothing to change). The backing set is always a
 * freshly created collection that is never exposed to callers.
 *
 * @author Sergey Polovko
 */
@Immutable
public final class ShardIds {
    /** Instance containing no partitions. */
    public static final ShardIds EMPTY = new ShardIds(Collections.emptySet());

    /** Maximum number of ids rendered by {@link #toString()} before the output is truncated. */
    private static final int TO_STRING_LIMIT = 5;

    private final Set<ShardPartitionId> shardsPartitions;
    // murmurHash3 of shardsPartitions.hashCode(); cached for getHash() and as a cheap equals() pre-check
    private final long hash;
    // number of DISTINCT shard ids among shardsPartitions (less than the partition count when a
    // shard is split into several partitions)
    private final int shardsCount;

    /**
     * @param shardsPartitions set this instance takes ownership of; it must not be retained or
     *                         mutated by the caller afterwards
     */
    private ShardIds(Set<ShardPartitionId> shardsPartitions) {
        this.shardsPartitions = shardsPartitions;
        this.hash = calcHashCode(shardsPartitions);
        var shards = new IntOpenHashSet();
        for (var key : shardsPartitions) {
            shards.add(key.shardId());
        }
        this.shardsCount = shards.size();
    }

    /**
     * Creates an instance where each id is mapped to {@code ShardPartitionId.onePartitionShard(id)}.
     *
     * @param shardIds shard ids; duplicates are collapsed
     * @return new instance
     * @deprecated partition-unaware; use {@link #ofPartitions(ShardPartitionId...)} instead
     */
    @Deprecated
    public static ShardIds ofWholeShards(int... shardIds) {
        var shards = new HashSet<ShardPartitionId>(shardIds.length);
        for (int id : shardIds) {
            shards.add(ShardPartitionId.onePartitionShard(id));
        }
        return new ShardIds(shards);
    }

    /**
     * Creates an instance where each id is mapped to {@code ShardPartitionId.onePartitionShard(id)}.
     *
     * @param shardIds shard ids
     * @return new instance
     * @deprecated partition-unaware; use {@link #ofPartitions(ShardPartitionId...)} instead
     */
    @Deprecated
    public static ShardIds ofWholeShards(IntSet shardIds) {
        var shards = new HashSet<ShardPartitionId>(shardIds.size());
        for (var id : shardIds) {
            shards.add(ShardPartitionId.onePartitionShard(id));
        }
        return new ShardIds(shards);
    }

    /**
     * Creates an instance from explicit partitions.
     *
     * @param partitions partitions to include; duplicates are collapsed
     * @return new instance
     */
    public static ShardIds ofPartitions(ShardPartitionId... partitions) {
        return new ShardIds(new HashSet<>(Arrays.asList(partitions)));
    }

    /**
     * @return {@code true} if this instance holds no partitions
     */
    public boolean isEmpty() {
        return shardsCount == 0;
    }

    /**
     * @return number of distinct shard ids (NOT the number of partitions)
     */
    public int size() {
        return shardsCount;
    }

    /**
     * @return 64-bit hash derived from the set of partitions; equal instances have equal hashes
     */
    public long getHash() {
        return hash;
    }

    /**
     * Checks whether {@code ShardPartitionId.onePartitionShard(shardId)} is a member of this set.
     *
     * <p>NOTE(review): this only matches the entry produced by {@code onePartitionShard}; other
     * partitions of the same shard are not considered — confirm this is intended.
     *
     * @param shardId shard id to look up
     * @return {@code true} if the whole-shard entry is present
     */
    public boolean containsShard(int shardId) {
        return shardsPartitions.contains(ShardPartitionId.onePartitionShard(shardId));
    }

    /**
     * @param key partition to look up
     * @return {@code true} if the partition is a member of this set
     */
    public boolean containsPartition(ShardPartitionId key) {
        return shardsPartitions.contains(key);
    }

    /**
     * Returns the distinct shard ids of all partitions in this set.
     *
     * @return unmodifiable set of shard ids
     * @deprecated drops partition information; prefer partition-aware accessors
     */
    @Deprecated
    public IntSet getShards() {
        var shards = new IntOpenHashSet(shardsPartitions.size());
        for (var key : shardsPartitions) {
            shards.add(key.shardId());
        }
        return IntSets.unmodifiable(shards);
    }

    /**
     * Returns a copy without every partition whose shard id is in {@code remove}, then with a
     * {@code onePartitionShard} entry added for every id in {@code add}. An id present in both sets
     * therefore ends up present.
     *
     * @param add shard ids to add
     * @param remove shard ids whose partitions are removed
     * @return {@code this} if both sets are empty, otherwise a new instance
     * @deprecated partition-unaware; use {@link #addRemovePartitions(Set, Set)} instead
     */
    @Deprecated
    public ShardIds addRemoveShards(IntSet add, IntSet remove) {
        if (add.isEmpty() && remove.isEmpty()) {
            return this;
        }

        // size by partitions (not distinct shards) plus the entries being added
        var copy = new HashSet<ShardPartitionId>(shardsPartitions.size() + add.size());
        for (var shard : shardsPartitions) {
            if (!remove.contains(shard.shardId())) {
                copy.add(shard);
            }
        }
        copy.addAll(add.stream()
                .map(shardId -> ShardPartitionId.onePartitionShard(shardId))
                .collect(Collectors.toList())
        );
        return new ShardIds(copy);
    }

    /**
     * Returns a copy with {@code remove} removed and then {@code add} added; a partition present in
     * both sets therefore ends up present.
     *
     * @param add partitions to add
     * @param remove partitions to remove
     * @return {@code this} if both sets are empty, otherwise a new instance
     */
    public ShardIds addRemovePartitions(Set<ShardPartitionId> add, Set<ShardPartitionId> remove) {
        if (add.isEmpty() && remove.isEmpty()) {
            return this;
        }

        var copy = new HashSet<>(shardsPartitions);
        copy.removeAll(remove);
        copy.addAll(add);
        return new ShardIds(copy);
    }

    /** Spreads the set's 32-bit {@code hashCode()} into a 64-bit value via murmurHash3. */
    private static long calcHashCode(Set<ShardPartitionId> shardIds) {
        return HashCommon.murmurHash3((long) shardIds.hashCode());
    }

    /**
     * Renders the hash and the shard ids of at most {@value #TO_STRING_LIMIT} partitions, followed by
     * the number of partitions that were omitted.
     */
    @Override
    public String toString() {
        var sb = new StringBuilder("ShardIds");
        sb.append("{hash=").append(Long.toHexString(hash));
        sb.append(", ids=[");

        var it = shardsPartitions.iterator();
        for (int i = 0; i < TO_STRING_LIMIT && it.hasNext(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(it.next().shardId());
        }
        if (it.hasNext()) {
            // the loop walks partitions, so the remainder must be counted in partitions too;
            // using the distinct-shard count here could yield a wrong or negative number
            sb.append(", ... (").append(shardsPartitions.size() - TO_STRING_LIMIT).append(" more)");
        }
        sb.append("]}");

        return sb.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        var that = (ShardIds) o;
        // cheap pre-check: different cached hashes imply different sets
        if (hash != that.hash) return false;
        return shardsPartitions.equals(that.shardsPartitions);
    }

    @Override
    public int hashCode() {
        return Long.hashCode(hash);
    }
}
