import argparse

import sys


def read_top(filename):
    """Parse a rate report into a mapping of stage name -> rate.

    Each non-blank line is expected to look like::

        stage_id.du_id Total logs: 100500.5, Pods total/affected: 10/8 (80.0%) - True

    The first whitespace-separated field is taken as the stage name and the
    fourth field, with its trailing comma removed, as the rate. If a stage
    appears on several lines, the last one wins.

    :param filename: path to the report file.
    :return: dict mapping stage name (str) to rate (float).
    """
    stage_to_rate = {}
    with open(filename, 'r', encoding='utf-8') as reader:
        for line in reader:
            parts = line.split()
            if not parts:
                # Tolerate blank lines (e.g. an empty line at end of file).
                continue

            name = parts[0]
            # The number is followed by a comma ("100500.5,"); strip the comma
            # instead of blindly dropping the last character.
            rate = float(parts[3].rstrip(','))

            stage_to_rate[name] = rate

    return stage_to_rate


# Keys of the per-stage limits dicts; byte_key and message_key are also
# printed as field names in the generated config.
byte_key, message_key = 'rate', 'messages_rate'
# Prefix for the key that stores a stage's observed rate next to its limit.
current_prefix = 'current_'
# Key set to True on stages with at least one overridden limit.
override_key = 'overridden'


def generate_stage_limits(byte_rates_file_name,
                          message_rates_file_name,
                          default_byte_limit,
                          default_message_limit,
                          min_override_percentage,
                          override_multiplier):
    """Compute per-stage throttling limits from byte and message rate reports.

    Every stage found in either report gets a limits dict, initialised from
    the defaults, that also records its observed rate under
    ``current_prefix + key``. A limit is overridden when both of these hold:

    * the observed rate is at least ``min_override_percentage`` of the
      default limit;
    * ``rate * override_multiplier`` is strictly greater than the default.

    The overridden value is ``rate * override_multiplier``, and the stage is
    marked with ``override_key = True``. The second condition keeps a
    white-list entry from ever tightening a stage's limit below the default.

    :param byte_rates_file_name: path to the bytes-rate report (see read_top).
    :param message_rates_file_name: path to the messages-rate report.
    :param default_byte_limit: default byte limit.
    :param default_message_limit: default message limit.
    :param min_override_percentage: fraction of the default limit a rate must
        reach to be considered for an override.
    :param override_multiplier: factor applied to the observed rate to get the
        overridden limit.
    :return: tuple ``(stage_limits, default_limits)``. ``stage_limits`` maps
        stage name to its limits dict; ``default_limits`` holds the defaults.
    """
    default_limits = {
        byte_key: default_byte_limit,
        message_key: default_message_limit,
    }

    stage_limits = {}

    byte_rates = read_top(byte_rates_file_name)
    message_rates = read_top(message_rates_file_name)

    def override_for_top(top_rates, limit_key):
        default_limit = default_limits[limit_key]
        threshold = min_override_percentage * default_limit
        for stage, rate in top_rates.items():
            # Each stage gets its own copy of the defaults.
            limits_for_stage = stage_limits.setdefault(stage, dict(default_limits))
            limits_for_stage[current_prefix + limit_key] = rate

            new_limit = rate * override_multiplier
            # Only override when it actually raises the limit; a white list
            # must not throttle a stage harder than the default does.
            if rate >= threshold and new_limit > default_limit:
                limits_for_stage[limit_key] = new_limit
                limits_for_stage[override_key] = True

    override_for_top(message_rates, message_key)
    override_for_top(byte_rates, byte_key)

    return stage_limits, default_limits


# Binary size units (powers of 1024), expressed in bytes.
BYTE = 1
KILOBYTE = 1024 * BYTE
MEGABYTE = 1024 * KILOBYTE
GIGABYTE = 1024 * MEGABYTE
TERABYTE = 1024 * GIGABYTE
PETABYTE = 1024 * TERABYTE

# Parallel lists of unit sizes and their config suffixes, in ascending order.
# to_byte_config relies on this ordering when choosing a unit.
sizes = [BYTE, KILOBYTE, MEGABYTE, GIGABYTE, TERABYTE, PETABYTE]
suffixes = ['b', 'kb', 'mb', 'gb', 'tb', 'pb']

# Lowercase suffix -> unit size in bytes, keeping the ascending unit order.
suffix_to_size = dict(zip(suffixes, sizes))


def from_byte_config(byte_string):
    """Parse a size string such as ``'15mb'`` into a number of bytes.

    The suffix is matched case-insensitively against the keys of
    ``suffix_to_size``. The text before the suffix must consist of decimal
    digits only. Surrounding whitespace is ignored.

    :param byte_string: size string to parse.
    :return: the size in bytes, or -1 if the string is not a valid size
        (callers test for this sentinel).
    """
    byte_string = byte_string.strip().lower()
    for suffix, size in suffix_to_size.items():
        if not byte_string.endswith(suffix):
            continue
        number = byte_string[:-len(suffix)]
        # isdecimal() (unlike isdigit()) accepts exactly what int() can parse;
        # e.g. '²'.isdigit() is True but int('²') raises ValueError.
        if number.isdecimal():
            return int(number) * size

    return -1


def to_byte_config(byte_rate):
    """Format a byte rate as a double-quoted config size, e.g. ``'"15mb"'``.

    The unit is chosen so that ``unit <= byte_rate < next unit``. Values below
    one kilobyte use bytes, and values at or above the largest unit use that
    unit. The count is rounded up, so for positive rates the emitted size is
    never smaller than ``byte_rate``.

    :param byte_rate: rate in bytes; converted with int() first.
    :return: quoted size string such as ``'"3kb"'``.
    """
    byte_rate = int(byte_rate)
    last = len(sizes) - 1
    for i, size in enumerate(sizes):
        # Anything at or above the largest unit is expressed in that unit
        # rather than falling off the end of the loop.
        if i == last or byte_rate < sizes[i + 1]:
            # Ceiling division so the limit never undershoots the rate.
            rate = (byte_rate - 1) // size + 1
            return '"{rate}{suffix}"'.format(rate=rate, suffix=suffixes[i])


def print_stage_limits(stage_limits, default_limits, time_period):
    """Print the white-list config block for all overridden stages to stdout.

    Stages are written in sorted order. Stages whose limits dict has no
    ``override_key`` entry are skipped. Each entry is annotated with the data
    collection period. For the byte and message limits, it also shows the
    observed rate (or "Unknown") and the default limit.

    :param stage_limits: mapping of stage name to its limits dict.
    :param default_limits: mapping of limit key to its default value.
    :param time_period: dict with 'start' and 'end' labels of the data period.
    """
    def emit_limit(entry, field, fmt):
        observed = entry.get(current_prefix + field)
        shown = '"Unknown"' if observed is None else fmt(observed)
        print("\t\t// current {key} = {value}".format(key=field, value=shown))
        print("\t\t// default limit = {value}".format(value=fmt(default_limits[field])))
        print("\t\t{key} = {value}".format(key=field, value=fmt(entry[field])))

    print('white_list_limits = {')
    for stage_name in sorted(stage_limits):
        entry = stage_limits[stage_name]
        if entry.get(override_key) is not None:
            print('\t{stage} ='.format(stage=stage_name), '{')
            print("\t\t// data from {start} - {end}".format(
                start=time_period['start'],
                end=time_period['end'],
            ))
            print()

            emit_limit(entry, byte_key, to_byte_config)
            emit_limit(entry, message_key, int)

            print('\t},')
    print('}')

###
# Generates a config with overridden throttling limits for DUs whose byte/message log rates are high.
# Required arguments: the byte-rate and message-rate report files, and the start/end of the data collection period.
# The config is printed to stdout.
#
# Usage example:
#
# python3 configs/throttling_white_list_generator.py '/home/name/bytes.txt' '/home/name/messages.txt' 05.07.21 05.08.21 > ~/throttling_white_list.conf
#
# Input files are expected in the same format as the ones attached to DEPLOY-4010:
# https://st.yandex-team.ru/DEPLOY-4010#61110818b232f55213c93306
#
# bytes:
# stage_id.du_id Total logs: 100500.5, Pods total/affected: 10/8 (80.0%) - True
#
# messages:
# stage_id.du_id Total logs: 123456.7, Pods total/affected: 10/5 (50.0%) - True
###

if __name__ == "__main__":
    # Command-line entry point: parse arguments, compute the per-stage limits
    # and print the white-list config to stdout.
    parser = argparse.ArgumentParser(description='Manage throttling white list generation')
    parser.add_argument("bytes_file", help="Path to file with bytes / sec")
    parser.add_argument("messages_file", help="Path to file with messages / sec")
    parser.add_argument("time_start", help="Start of data collection period (example: '05.07.21')")
    parser.add_argument("time_end", help="End of data collection period (example: '05.08.21')")
    parser.add_argument('-bl', "--bytes_limit", default='15mb', help="Default bytes limit (default: 15mb)")
    # type= lets argparse reject malformed numbers with a usage error instead
    # of a traceback further down.
    parser.add_argument('-ml', "--messages_limit",
                        type=int,
                        default=20000,
                        help="Default messages limit (default: 20000)")
    parser.add_argument('-op', "--min_override_percentage",
                        type=float,
                        default=0.9,
                        help="override a limit if current rate >= default_limit * min_override_percentage (default: 0.9)")
    parser.add_argument('-om', "--override_multiplier",
                        type=float,
                        default=2.0,
                        help="if stage_limit is overridden - it would be stage_limit * multiplier (default: 2)")

    args = parser.parse_args()

    bytes_limit = from_byte_config(args.bytes_limit)
    if bytes_limit == -1:
        # stdout is typically redirected into the generated config file, so
        # errors must go to stderr to keep them out of it.
        print("Incorrect byte limit string:", args.bytes_limit, file=sys.stderr)
        sys.exit(1)

    stage_limits, default_limits = generate_stage_limits(
        args.bytes_file,
        args.messages_file,
        bytes_limit,
        args.messages_limit,
        args.min_override_percentage,
        args.override_multiplier,
    )

    time_period = {
        'start': args.time_start,
        'end': args.time_end,
    }

    print_stage_limits(stage_limits, default_limits, time_period)
