package ru.yandex.metabase.client.impl;

import java.util.BitSet;
import java.util.stream.IntStream;

import javax.annotation.Nullable;

import ru.yandex.bolts.function.Function;
import ru.yandex.solomon.labels.shard.ShardKey;

/**
 * @author Egor Litvinenko
 * */
public class MetabaseShard implements PartitionedShard {

    /** Creation time of this view; used to pick the most recent view when merging. */
    private final long createdAt;
    private final ShardKey key;
    private final int numId;
    private final long generationId;

    /** Fqdn per partition id; an entry stays {@code null} until that partition is discovered. */
    private final String[] partitionFqdns;

    /**
     * Per-partition flags. Not final: once every partition is discovered, {@link #tryFreeze()}
     * replaces them with instances obtained from {@code MyInterners}, which must not be mutated.
     */
    private BitSet partitionIsReady;
    private BitSet partitionAllowsNew;
    /** True once all partitions are discovered and every one of them is ready. */
    private boolean shardIsReady = false;
    /** True once all partitions are discovered and every one of them allows new entries. */
    private boolean shardIsAllowNew = false;

    /**
     * Creates a shard view with no partitions discovered yet.
     *
     * @param createdAt       creation time of this view, compared when merging views
     * @param key             shard key
     * @param numId           numeric shard id (also the identity used by equals/hashCode)
     * @param generationId    generation of this view; older partition reports are ignored
     * @param totalPartitions number of partitions in the shard
     */
    public MetabaseShard(
            long createdAt,
            ShardKey key,
            int numId,
            long generationId,
            int totalPartitions)
    {
        this(createdAt, key, numId, new String[totalPartitions],
                new BitSet(totalPartitions), new BitSet(totalPartitions),
                generationId);
    }

    private MetabaseShard(
            long createdAt,
            ShardKey key,
            int numId,
            String[] partitionFqdnIndex,
            BitSet partitionIsReady,
            BitSet partitionAllowsNew,
            long generationId)
    {
        this.createdAt = createdAt;
        this.key = key;
        this.numId = numId;
        this.partitionFqdns = partitionFqdnIndex;
        this.partitionIsReady = partitionIsReady;
        this.partitionAllowsNew = partitionAllowsNew;
        this.generationId = generationId;
    }

    @Override
    public ShardKey getKey() {
        return key;
    }

    @Override
    public int getNumId() {
        return numId;
    }

    /**
     * Returns the fqdn hosting the given partition, or {@code null} if that partition has not
     * been discovered yet.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code partitionId} is outside
     *         {@code [0, getTotalPartitions())}
     */
    @Nullable
    @Override
    public String getFqdn(int partitionId) {
        return partitionFqdns[partitionId];
    }

    @Override
    public long getGenerationId() {
        return generationId;
    }

    @Override
    public int getTotalPartitions() {
        return partitionFqdns.length;
    }

    /** Returns how many partitions currently have a known fqdn. */
    public int getDiscoveredPartitions() {
        int discoveredPartitions = 0;
        for (String fqdn : partitionFqdns) {
            if (fqdn != null) {
                discoveredPartitions++;
            }
        }
        return discoveredPartitions;
    }

    /** Returns true only after all partitions are discovered and all of them are ready. */
    public boolean isReady() {
        return shardIsReady;
    }

    /** Returns true only after all partitions are discovered and all of them allow new entries. */
    public boolean isAllowNew() {
        return shardIsAllowNew;
    }

    @Override
    public PartitionKey withOnePartition(int partitionId) {
        return new PartitionKey(this, partitionId);
    }

    /**
     * Streams the ids of the partitions currently marked ready, in ascending order.
     */
    @Override
    public IntStream partitionStream() {
        // Fast path: a fully ready shard serves every partition, no bit scanning needed.
        if (shardIsReady) {
            return IntStream.range(0, getTotalPartitions());
        }
        // Bits are only ever set for ids below getTotalPartitions(), so the set bits are exactly
        // the ready partitions.
        return partitionIsReady.stream();
    }

    /**
     * Records the location and flags of one partition.
     *
     * <p>The report is ignored if it comes from an older generation than this view, or if the
     * partition has already been discovered (the first fqdn reported wins). When the last missing
     * partition is added, the shard is frozen (see {@link #tryFreeze()}).
     *
     * @throws ArrayIndexOutOfBoundsException if {@code partitionId} is outside
     *         {@code [0, getTotalPartitions())}
     */
    void addPartition(long generationId, int partitionId, String fqdn, boolean ready, boolean allowNew) {
        if (generationId < this.generationId || discovered(partitionId)) {
            return;
        }
        partitionFqdns[partitionId] = fqdn;
        if (allowNew) {
            partitionAllowsNew.set(partitionId);
        }
        if (ready) {
            partitionIsReady.set(partitionId);
        }
        tryFreeze();
    }

    private boolean discovered(int partitionId) {
        return partitionFqdns[partitionId] != null;
    }

    private boolean allDiscovered() {
        for (String fqdn : partitionFqdns) {
            if (fqdn == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Once every partition is discovered, replaces the flag bit sets with interned instances and
     * computes the shard-level flags. The interned bit sets are never mutated afterwards, because
     * {@link #addPartition} returns early for already discovered partitions.
     */
    private void tryFreeze() {
        if (allDiscovered()) {
            partitionAllowsNew = MyInterners.bitset().intern(partitionAllowsNew, getTotalPartitions());
            partitionIsReady = MyInterners.bitset().intern(partitionIsReady, getTotalPartitions());
            shardIsReady = partitionIsReady.cardinality() == getTotalPartitions();
            shardIsAllowNew = partitionAllowsNew.cardinality() == getTotalPartitions();
        }
    }

    @Override
    public String toString() {
        return "Shard{" +
                "numId=" + Integer.toUnsignedLong(numId) +
                ", key='" + getKey() + '\'' +
                ", totalPartitions=" + getTotalPartitions() +
                ", generationId=" + getGenerationId() +
                ", ready=" + isReady() +
                ", allowNew=" + isAllowNew() +
                '}';
    }

    /** Shards are identified solely by {@code numId}. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MetabaseShard shard = (MetabaseShard) o;
        return numId == shard.numId;
    }

    @Override
    public int hashCode() {
        return numId;
    }

    /**
     * Merges two views of the same shard, e.g. statuses reported by one node for different
     * generations, or by different nodes at different times.
     *
     * <p>If the generation ids differ, the view with the larger generation id is returned as is.
     * Otherwise a new view is built from the more recently created one ({@code createdAt}; on a
     * tie {@code left} counts as the more recent). For each partition, the fqdn and the
     * ready/allow-new flags are taken from the recent view if it has discovered that partition,
     * and from the older view otherwise. The same partition can appear on several servers, for
     * example while a shard is being moved, so the recent view takes priority.
     *
     * @return one of the arguments, or a new merged view
     */
    static MetabaseShard mergeShardInfoFromServers(MetabaseShard left, MetabaseShard right) {
        if (left.generationId > right.generationId) {
            return left;
        }
        else if (left.generationId < right.generationId) {
            return right;
        }
        // Only checked with -ea. If the counts differ anyway, the latest view's partition count
        // wins and entries of the older view beyond it are ignored.
        assert left.getTotalPartitions() == right.getTotalPartitions();
        var latest = left.createdAt < right.createdAt ? right : left;
        var oldest = latest == left ? right : left;
        final MetabaseShard result = new MetabaseShard(
                latest.createdAt,
                latest.key,
                latest.numId,
                latest.generationId,
                latest.getTotalPartitions()
        );

        for (int i = 0; i < result.getTotalPartitions(); i++) {
            if (latest.discovered(i)) {
                result.partitionFqdns[i] = latest.partitionFqdns[i];
            } else if (i < oldest.getTotalPartitions() && oldest.discovered(i)) {
                result.partitionFqdns[i] = oldest.partitionFqdns[i];
            }
        }

        moveBits(latest, oldest, result, l -> l.partitionIsReady);
        moveBits(latest, oldest, result, l -> l.partitionAllowsNew);

        result.tryFreeze();
        return result;
    }

    /**
     * Fills {@code result}'s bit set (selected by {@code getter}) with every bit of
     * {@code latest}'s set, plus those bits of {@code oldest}'s set whose partitions
     * {@code latest} has not discovered. This mirrors the fqdn selection in
     * {@link #mergeShardInfoFromServers}. {@code result}'s bit set is expected to be freshly
     * created and not interned.
     */
    private static void moveBits(
            MetabaseShard latest,
            MetabaseShard oldest,
            MetabaseShard result,
            Function<MetabaseShard, BitSet> getter
    ) {
        var latestSet = getter.apply(latest);
        var oldestSet = getter.apply(oldest);
        var resultSet = getter.apply(result);
        resultSet.or(latestSet);
        int limit = result.getTotalPartitions();
        for (int i = oldestSet.nextSetBit(0); i >= 0 && i < limit; i = oldestSet.nextSetBit(i + 1)) {
            if (!latest.discovered(i)) {
                resultSet.set(i);
            }
        }
    }

}
