import math
import argparse

from scipy.stats import binom
import numpy as np


def nth_root(number, power):
    """Return ``number ** (1 / power)``, the ``power``-th root of ``number``.

    Used to split an overall SLA evenly across ``power`` independent shards.
    """
    return number ** (1 / power)


def compute_shards_number(shardnum, total_sla, active_replicas, replica_level, infra_reliability):
    """Return the replicas needed per shard so that ``shardnum`` shards jointly meet ``total_sla``.

    The overall SLA is split evenly across shards: assuming the shards fail
    independently and all of them are needed, each one must reach
    ``(total_sla / 100) ** (1 / shardnum)`` (converted back to percent).

    :param shardnum: number of shards that must all be up (>= 1).
    :param total_sla: required overall SLA, in percent.
    :param active_replicas: replicas that must be alive in each shard.
    :param replica_level: reliability of a single replica, in percent.
    :param infra_reliability: infrastructure reliability (fraction or percent).
    :return: replicas needed per shard, as computed by ``compute_replicas_number``.
    :raises ValueError: if ``shardnum`` is less than 1.
    """
    # shardnum == 0 would divide by zero in the root; 0 < shardnum < 1 would
    # make the per-shard target *looser* than the total SLA, which is wrong.
    if shardnum < 1:
        raise ValueError("shardnum must be at least 1")

    individual_shard_reliability = nth_root(total_sla / 100, shardnum) * 100
    return compute_replicas_number(active_replicas, replica_level, individual_shard_reliability, infra_reliability)


def compute_replicas_number(active_replicas, replica_level, sla_level, infra_reliability):
    """Return the minimal number of replicas one shard needs to meet an SLA.

    Each replica is modelled as independently available with probability
    ``p = replica_level / 100 * infra_reliability``.  The shard is up when at
    least ``active_replicas`` of its replicas are available.  The result is the
    smallest ``n`` such that ``P(Binom(n, p) >= active_replicas) >= sla_level / 100``.

    :param active_replicas: replicas that must be alive simultaneously (>= 1).
    :param replica_level: reliability of a single replica, in percent, open interval (0, 100).
    :param sla_level: required shard availability, in percent, open interval (0, 100).
    :param infra_reliability: infrastructure reliability, either as a fraction
        in (0, 1) or, when greater than 1, as a percent.
    :return: the minimal replica count ``n`` (always ``>= active_replicas``).
    :raises ValueError: if a probability falls outside (0, 1) after
        normalisation, or ``active_replicas`` is less than 1.
    """
    bad_params_err = "Incorrect probabilities provided"

    # Accept the infra reliability either as a fraction (0.96) or a percent (96).
    if infra_reliability > 1:
        infra_reliability = infra_reliability / 100

    if infra_reliability <= 0 or infra_reliability >= 1:
        raise ValueError(bad_params_err)

    replica_level, sla_level = (replica_level / 100) * infra_reliability, sla_level / 100

    if sla_level <= 0 or sla_level >= 1 or replica_level <= 0 or replica_level >= 1:
        raise ValueError(bad_params_err)

    if active_replicas < 1:
        raise ValueError("active_replicas must be at least 1")

    def meets_sla(replicas):
        # binom.sf(k - 1, n, p) == P(X >= k); more accurate than 1 - cdf when
        # the SLA is very close to 100%.
        return binom.sf(active_replicas - 1, replicas, replica_level) >= sla_level

    # Invariant: `failing` does NOT meet the SLA and `passing` does (once the
    # first loop exits).  Fewer than `active_replicas` replicas give
    # probability 0, so `active_replicas - 1` is a known failing count.
    failing = active_replicas - 1
    passing = active_replicas
    # Initial stride; >= 1 because replica_level < 1.
    step = math.ceil(active_replicas * (1 - replica_level))

    # Exponential search for an upper bound.  P(X >= k) is non-decreasing in n
    # and tends to 1 > sla_level, so this terminates.
    while not meets_sla(passing):
        failing = passing
        passing += step
        step *= 2

    # Bisect the (failing, passing] bracket down to the minimal passing count.
    while passing - failing > 1:
        middle = (failing + passing) // 2
        if meets_sla(middle):
            passing = middle
        else:
            failing = middle

    return passing


def main():
    """Parse command-line arguments and print the replicas needed per shard."""
    parser = argparse.ArgumentParser(
        description="Estimate how many replicas each shard needs to satisfy an SLA.")
    # The first four options have no sensible default; without required=True an
    # omitted flag becomes None and crashes later with an opaque TypeError.
    parser.add_argument("--alive_required", type=int, required=True,
                        help="How many replicas should be alive in each shard")
    parser.add_argument("--single_replica_reliability", type=float, required=True,
                        help="Replica individual reliability, in percent (e.g. 99.5)")
    parser.add_argument("--sla_required", type=float, required=True,
                        help="Required SLA, in percent (e.g. 99.99)")
    parser.add_argument("--shard_number", type=int, required=True,
                        help="Number of shards required for service to operate")
    parser.add_argument("--infra_reliability", type=float, default=0.96,
                        help="Infra total reliability, as a fraction (0.96) or a percent (96)")
    args = parser.parse_args()

    required_replicas_number = compute_shards_number(args.shard_number, args.sla_required, args.alive_required,
                                                     args.single_replica_reliability, args.infra_reliability)

    print(f"{required_replicas_number} replicas per shard are needed to satisfy requirements")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
