# coding: utf-8

from __future__ import absolute_import, print_function

import logging
from collections import defaultdict

import click

from infra.rtc.iolimit_ticketer.cli import cli


def solve(request, capacity_list):
    """Allocate *request* on the first entry of *capacity_list* that can hold it.

    First-fit placement: entries are scanned in list order and *request* is
    subtracted from the first one that is ``>= request``. The list is modified
    in place.

    :param request: amount to allocate.
    :param capacity_list: mutable list of remaining capacities, one per bin.
    :return: True if the request was placed, False if no entry was large
        enough (in which case *capacity_list* is left untouched).
    """
    for idx, capacity in enumerate(capacity_list):
        if capacity >= request:
            # Safe to mutate while enumerating: we return immediately after.
            capacity_list[idx] -= request
            return True
    return False


class Storage(object):
    """Bin-packing state for one storage class (instances are created per node by lvm_fit).

    Collects disk capacities and volume requests, then places the requests
    largest-first with ``solve`` (first fit over disks sorted by decreasing
    capacity).
    """

    def __init__(self):
        self.capacity_list = []  # remaining capacity per disk; mutated during placement
        self.request_list = []  # (capacity_bytes, service, cluster, pod) tuples
        self.allocated_request_list = []  # sizes of requests that were placed successfully

    def add_capacity(self, capacity_bytes):
        """Register a disk with *capacity_bytes* of usable space."""
        self.capacity_list.append(capacity_bytes)

    def add_request(self, capacity_bytes, service, cluster, pod):
        """Register a request for *capacity_bytes* made by *pod* of *service* in *cluster*."""
        self.request_list.append((capacity_bytes, service, cluster, pod))

    def find_pods_that_cant_fit(self):
        """Place every registered request and yield the ones that do not fit.

        Yields ``(service, cluster, pod, capacity_bytes)`` for each request no
        disk can hold. Placed request sizes are appended to
        ``allocated_request_list`` as the generator advances, so it must be
        fully consumed for that list to be complete.
        """
        if not self.request_list:
            return
        assert self.capacity_list
        self.capacity_list.sort(reverse=True)
        # Order by size only: sorting whole tuples would fall through to the
        # service/cluster/pod objects when sizes are equal, and those are not
        # necessarily orderable (TypeError on Python 3).
        self.request_list.sort(key=lambda request_item: request_item[0], reverse=True)
        for request, service, cluster, pod in self.request_list:
            if solve(request, self.capacity_list):
                self.allocated_request_list.append(request)
            else:
                yield service, cluster, pod, request


def _log_pass_summary(pass_name, pods):
    """Log the number of pods, distinct services and distinct accounts in *pods*.

    :param pass_name: label that starts the message, e.g. "First" or "Second".
    :param pods: list of ``(request, storage_class, service_stat, cluster_stat, pod_stat)`` tuples.
    """
    logging.info(
        "%s pass: pods found %d, services found %d, accounts found %d",
        pass_name,
        len(pods),
        len({service_stat.service_id for _, _, service_stat, _, _ in pods}),
        len({service_stat.account_id for _, _, service_stat, _, _ in pods})
    )


@cli.command('lvm_fit')
@click.pass_context
def lvm_fit(ctx):
    """Find which volumes can't survive LVM."""

    # Group every scheduled pod by the (cluster, node) it is placed on.
    nodes = defaultdict(list)
    for service_stat in ctx.obj.yp_stat.service_map.values():
        for cluster_stat in service_stat.clusters.values():
            if cluster_stat.cluster_name in ("sas-test", "man-pre"):
                continue
            for pod_stat in cluster_stat.pods.values():
                if pod_stat.node_id:
                    nodes[(cluster_stat.cluster_name, pod_stat.node_id)].append((service_stat, cluster_stat, pod_stat))

    # First pass: try to place each pod's volumes on the disks of its own node.
    pods_left = []
    # (cluster, storage_class) -> usable capacity of every disk of that class in the cluster
    disk_capacity_map = defaultdict(list)
    # (cluster, storage_class) -> sizes of requests placed during the first pass
    allocated_request_map = defaultdict(list)
    for (cluster_name, node_id), pod_list in sorted(nodes.items()):
        if node_id not in ctx.obj.hm_stat.node_map:
            logging.warning("Unknown node %s", node_id)
            continue
        hm_descriptor = ctx.obj.hm_stat.node_map[node_id]
        storage_map = defaultdict(Storage)
        for disk in hm_descriptor.disks:
            if disk.storage_class not in ("ssd", "hdd"):
                continue
            # Only 90% of each disk is treated as usable.
            capacity_bytes = int(disk.capacity_bytes * 0.9)
            storage_map[disk.storage_class].add_capacity(capacity_bytes)
            disk_capacity_map[(cluster_name, disk.storage_class)].append(capacity_bytes)
        for service_stat, cluster_stat, pod_stat in pod_list:
            for storage_class, storage in storage_map.items():
                capacity_request = sum(volume_desc.capacity for volume_desc in pod_stat.iter_volume_desc(storage_class))
                if capacity_request:
                    storage.add_request(capacity_request, service_stat, cluster_stat, pod_stat)
        for storage_class, storage in storage_map.items():
            for service_stat, cluster_stat, pod_stat, request in storage.find_pods_that_cant_fit():
                logging.warning("Not fit: %s/%s/%s/%s on %s/%s", service_stat.deploy_engine, service_stat.service_id,
                                cluster_stat.cluster_name, pod_stat.pod_id, node_id, storage_class)
                pods_left.append((request, storage_class, service_stat, cluster_stat, pod_stat))
            allocated_request_map[(cluster_name, storage_class)].extend(storage.allocated_request_list)

    _log_pass_summary("First", pods_left)

    # Second pass: try to place the leftovers anywhere within their cluster.
    assert disk_capacity_map
    for capacity_list in disk_capacity_map.values():
        assert capacity_list
        capacity_list.sort(reverse=True)

    pods_left_second_pass = []
    # Sort on (request, storage_class) only: a plain tuple sort would go on to
    # compare the stat objects on ties, which need not be orderable (TypeError on Python 3).
    pods_left.sort(key=lambda item: (item[0], item[1]), reverse=True)
    for item in pods_left:
        request, storage_class, _, cluster_stat, _ = item
        capacity_list = disk_capacity_map[(cluster_stat.cluster_name, storage_class)]
        if not solve(request, capacity_list):
            pods_left_second_pass.append(item)

    # Sanity check: the first-pass placements must still fit into the cluster
    # capacity that remains after the second pass.
    assert allocated_request_map
    for (cluster_name, storage_class), request_list in allocated_request_map.items():
        if not request_list:
            # Disks of this class exist in the cluster but nothing was placed on
            # them in the first pass; there is nothing to verify.
            continue
        capacity_list = disk_capacity_map[(cluster_name, storage_class)]
        capacity_list.sort(reverse=True)
        request_list.sort(reverse=True)
        not_solved = 0
        for request in request_list:
            if not solve(request, capacity_list):
                not_solved += 1
        if not_solved:
            logging.error("Not solved %s/%s: %d", cluster_name, storage_class, not_solved)
        assert not_solved == 0

    for request, storage_class, service_stat, cluster_stat, pod_stat in pods_left_second_pass:
        logging.error("Too big: %s/%s/%s/%s, requested %s/%d", service_stat.deploy_engine, service_stat.service_id,
                      cluster_stat.cluster_name, pod_stat.pod_id, storage_class, request)

    _log_pass_summary("Second", pods_left_second_pass)
