# coding: utf-8

from __future__ import absolute_import, print_function

import json
from collections import defaultdict

import click

from infra.rtc.iolimit_ticketer.cli import cli

# Per-disk-type model parameters, keyed by disk type name ("HDD", "SSD", "NVME").
# "storage_class" is the class disk_to_storage_class() reports for the type
# (NVME disks count as "ssd"); "total_bandwidth" is the per-physical-disk value
# returned by get_stupid_disk_bandwidth() and dm_disk_bandwidth().
# NOTE(review): the bandwidth figures (150 MiB, 500 MiB, 2 GiB) look like
# bytes per second — confirm the unit with the consumers of this table.
# NOTE(review): the rr/rw IOPS divisors and sr/sw bandwidth factors are not
# read by the code shown here; confirm their meaning where they are used.
DISK_MODEL = {
    "HDD": dict(
        rr_iops_divisor=128 * 1024,
        rw_iops_divisor=128 * 1024,
        sr_bandwidth_factor=1,
        sw_bandwidth_factor=1,
        total_bandwidth=150 * 1024 * 1024,
        storage_class="hdd",
    ),
    "SSD": dict(
        rr_iops_divisor=8 * 1024,
        rw_iops_divisor=8 * 1024,
        sr_bandwidth_factor=1,
        sw_bandwidth_factor=0.5,
        total_bandwidth=500 * 1024 * 1024,
        storage_class="ssd",
    ),
    "NVME": dict(
        rr_iops_divisor=4 * 1024,
        rw_iops_divisor=4 * 1024,
        sr_bandwidth_factor=1,
        sw_bandwidth_factor=1,
        total_bandwidth=2 * 1024 * 1024 * 1024,
        storage_class="ssd",
    ),
}


def get_stupid_disk_bandwidth(disk, disk_map):
    """Estimate a disk's bandwidth from the DISK_MODEL table.

    A disk without slaves contributes its model's ``total_bandwidth``. A disk
    with slaves contributes the recursive sum over those slaves that are
    present in *disk_map* (name -> disk); slaves missing from the map are
    ignored, so a disk whose slaves are all unknown yields 0.
    """
    slave_names = disk.slaves
    if slave_names:
        known_slaves = [disk_map[name] for name in slave_names if name in disk_map]
        return sum(get_stupid_disk_bandwidth(slave, disk_map) for slave in known_slaves)
    model = DISK_MODEL[disk.disk_type]
    return model["total_bandwidth"]


def disk_to_storage_class(disk):
    """Return the storage class DISK_MODEL assigns to ``disk.disk_type``.

    Raises KeyError for disk types missing from DISK_MODEL.
    """
    model = DISK_MODEL[disk.disk_type]
    return model["storage_class"]


def dm_disk_to_storage_class(disk):
    """Collapse ``disk.storage_class`` to "hdd" or "ssd".

    Only the exact value "hdd" stays "hdd"; every other value (e.g. "ssd",
    "nvme") is reported as "ssd".
    """
    if disk.storage_class != "hdd":
        return "ssd"
    return "hdd"


def dm_disk_bandwidth(disk):
    """Return DISK_MODEL's ``total_bandwidth`` for ``disk.storage_class``.

    The storage class is upper-cased to match the DISK_MODEL keys
    ("hdd" -> "HDD", "nvme" -> "NVME"); unknown classes raise KeyError.
    """
    model_key = disk.storage_class.upper()
    return DISK_MODEL[model_key]["total_bandwidth"]


# Accounts of this ABC service are left out of the demand totals.
# NOTE(review): the reason for excluding ABC id 1975 is not visible here — confirm.
_EXCLUDED_ABC_ID = 1975

# Bytes (50 GiB) subtracted from each YP node's shared capacity in the topology
# output. NOTE(review): presumably a per-node reserve — confirm the intent.
_TOPOLOGY_RESERVE_BYTES = 50 * 1024 * 1024 * 1024

# When a YP node has exactly one disk of a storage class, this fraction of it is
# attributed to LVM and the remainder is counted as shared supply.
_LVM_RATIO = {
    "hdd": 0.5,
    "ssd": 0.5,
}


def _ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is zero.

    Lets the report keep printing when a storage class has demand but no
    matching supply (or no LVM supply at all) in a data center.
    """
    if not denominator:
        return float("nan")
    return numerator / float(denominator)


def _collect_demand(dc, yp_stat):
    """Sum YP account disk limits and usage in the "default" segment of *dc*.

    Returns ``(capacity_list, capacity_needed, bandwidth_needed, capacity_usage)``,
    each keyed by storage class. ``capacity_list`` holds
    ``(capacity_limits, account)`` pairs for accounts with a non-zero limit.
    """
    capacity_list = defaultdict(list)
    capacity_needed = defaultdict(int)
    bandwidth_needed = defaultdict(int)
    capacity_usage = defaultdict(int)
    for account in yp_stat.account_map.values():
        if account.get_abc_id() == _EXCLUDED_ABC_ID:
            continue

        cluster = account.clusters.get(dc, {}).get("default")
        if cluster is None:
            continue
        for storage_class, disk in cluster.disks.items():
            if disk.capacity_limits:
                capacity_list[storage_class].append((disk.capacity_limits, account))
            capacity_needed[storage_class] += disk.capacity_limits
            bandwidth_needed[storage_class] += disk.bandwidth_limits
            capacity_usage[storage_class] += disk.capacity_usage
    return capacity_list, capacity_needed, bandwidth_needed, capacity_usage


def _collect_supply(dc, walle_stat, hm_stat, topology):
    """Sum disk capacity/bandwidth of the nodes of *dc* that Wall-E knows about.

    Nodes in the ``yp-iss-<dc>`` Wall-E project are split into shared and LVM
    supply and recorded in *topology* (``fqdn -> storage class ->
    {"capacity", "bandwidth"}``). For any other node, half of every typed oops
    disk with a filesystem size is counted as shared supply and nothing as LVM.

    Returns ``(capacity_available, bandwidth_available, lvm_capacity_available,
    lvm_bandwidth_available)``, each a ``defaultdict(int)`` keyed by storage class.
    """
    capacity_available = defaultdict(int)
    bandwidth_available = defaultdict(int)
    lvm_capacity_available = defaultdict(int)
    lvm_bandwidth_available = defaultdict(int)
    yp_project = "yp-iss-{}".format(dc)

    for node in hm_stat.node_map.values():
        if node.cluster != dc:
            continue
        walle_host = walle_stat.host_map.get(node.fqdn)
        if walle_host is None:
            continue

        if walle_host.project != yp_project:
            disk_map = {disk.name: disk for disk in node.oops_disks if disk.disk_type}
            for disk in node.oops_disks:
                if not disk.fs_size or not disk.disk_type:
                    continue
                storage_class = disk_to_storage_class(disk)
                capacity_available[storage_class] += disk.fs_size * 0.5
                bandwidth_available[storage_class] += get_stupid_disk_bandwidth(disk, disk_map) * 0.5
            continue

        disk_by_storage = defaultdict(list)
        for disk in node.disks:
            if disk.storage_class not in ("hdd", "ssd", "nvme"):
                continue
            disk_by_storage[dm_disk_to_storage_class(disk)].append(disk)

        # Looked up even when the node has no usable disks, so every YP node
        # gets a (possibly empty) topology entry.
        node_topology = topology[node.fqdn]
        for storage_class, disks in disk_by_storage.items():
            storage_topology = node_topology[storage_class]
            if len(disks) > 1:
                # The smallest disk is counted as shared; all others go to LVM.
                disks.sort(key=lambda d: d.capacity_bytes)
                shared_disk = disks[0]
                shared_bandwidth = dm_disk_bandwidth(shared_disk)
                capacity_available[storage_class] += shared_disk.capacity_bytes
                bandwidth_available[storage_class] += shared_bandwidth
                storage_topology["capacity"] += shared_disk.capacity_bytes - _TOPOLOGY_RESERVE_BYTES
                storage_topology["bandwidth"] += shared_bandwidth
                for lvm_disk in disks[1:]:
                    lvm_capacity_available[storage_class] += lvm_disk.capacity_bytes
                    lvm_bandwidth_available[storage_class] += dm_disk_bandwidth(lvm_disk)
            else:
                # A single disk is split between LVM and shared supply.
                disk = disks[0]
                lvm_share = _LVM_RATIO[storage_class]
                shared_share = 1 - lvm_share
                disk_bandwidth = dm_disk_bandwidth(disk)
                capacity_available[storage_class] += int(disk.capacity_bytes * shared_share)
                bandwidth_available[storage_class] += int(disk_bandwidth * shared_share)
                lvm_capacity_available[storage_class] += int(disk.capacity_bytes * lvm_share)
                lvm_bandwidth_available[storage_class] += int(disk_bandwidth * lvm_share)
                storage_topology["capacity"] += disk.capacity_bytes * shared_share - _TOPOLOGY_RESERVE_BYTES
                storage_topology["bandwidth"] += disk_bandwidth * shared_share

    return capacity_available, bandwidth_available, lvm_capacity_available, lvm_bandwidth_available


def _print_needed_vs_available(kind, needed_by_class, shared_by_class, lvm_by_class):
    """Print demand for *kind* ("capacity" or "bandwidth") against shared and LVM supply."""
    for storage_class, needed in needed_by_class.items():
        shared = shared_by_class[storage_class]
        lvm = lvm_by_class[storage_class]
        print(storage_class)
        print("total {} needed, bytes".format(kind), storage_class, needed)
        print("total {} available, bytes".format(kind), storage_class, shared,
              _ratio(shared, shared + lvm))
        print("total lvm {} available, bytes".format(kind), storage_class, lvm,
              _ratio(lvm, shared + lvm))
        print("total {} available, %".format(kind), storage_class, _ratio(needed, shared))


def compute(dc, walle_stat, yp_stat, hm_stat, topology):
    """Print a YP disk demand vs. supply report for data center *dc*.

    Demand is the sum of YP account disk limits/usage in the "default"
    segment of *dc*; supply comes from the nodes of *dc* known to Wall-E.
    *topology* is updated in place for ``yp-iss-<dc>`` nodes; it must create
    ``{"capacity": 0, "bandwidth": 0}`` entries on first access, as the nested
    defaultdict built by ``lvm_shared`` does. Ratios with a zero denominator
    are printed as ``nan`` instead of aborting the report.
    """
    capacity_list, capacity_needed, bandwidth_needed, capacity_usage = _collect_demand(dc, yp_stat)
    (capacity_available, bandwidth_available,
     lvm_capacity_available, lvm_bandwidth_available) = _collect_supply(dc, walle_stat, hm_stat, topology)

    print(dc)

    for storage_class, usage in capacity_usage.items():
        print(storage_class)
        print("total capacity usage, bytes", storage_class, usage)
        print("total capacity usage, %", storage_class, _ratio(usage, capacity_available[storage_class]))
        print("total lvm capacity usage, %", storage_class, _ratio(usage, lvm_capacity_available[storage_class]))

    _print_needed_vs_available("capacity", capacity_needed, capacity_available, lvm_capacity_available)
    _print_needed_vs_available("bandwidth", bandwidth_needed, bandwidth_available, lvm_bandwidth_available)

    for storage_class, capacities in capacity_list.items():
        print(storage_class)
        # Sort on the capacity only: a plain tuple sort would fall back to
        # comparing account objects on ties, which fails for unorderable types.
        capacities.sort(key=lambda item: item[0])
        # The accounts with the largest ~5% of capacity limits.
        top = capacities[int(len(capacities) * 0.95):]
        top_total = sum(capacity for capacity, _ in top)
        print("top", storage_class, len(top), [(account.get_abc_id(), capacity) for capacity, account in top])
        print("needed, %", storage_class, _ratio(top_total, capacity_needed[storage_class]))
        print("available, %", storage_class, _ratio(top_total, capacity_available[storage_class]))


@cli.command('lvm_shared')
@click.pass_context
def lvm_shared(ctx):
    """Check LVM + shared cluster topology."""
    # NOTE: click shows the docstring above as this command's help text.
    # Stats are read once, in this order, before any data center is processed.
    stats = ctx.obj
    hm_stat, walle_stat, yp_stat = stats.hm_stat, stats.walle_stat, stats.yp_stat

    def empty_storage_entry():
        return {"capacity": 0, "bandwidth": 0}

    # node fqdn -> storage class -> {"capacity", "bandwidth"}, shared by all DCs.
    topology = defaultdict(lambda: defaultdict(empty_storage_entry))
    for datacenter in ("iva", "myt", "sas", "man", "vla"):
        compute(datacenter, walle_stat, yp_stat, hm_stat, topology)
        # Blank line between per-DC reports.
        print("")

    with open("topology.json", "w") as output:
        json.dump(topology, output, indent=4)
