# coding: utf-8
from __future__ import division, print_function, unicode_literals

import abc
import six
import time
import functools

import yp.client as yp_c
import nanny_rpc_client as nanny_c
import google.protobuf.json_format as proto_json_format

from infra.nanny.yp_lite_api.proto import pod_sets_api_pb2
from infra.nanny.yp_lite_api.py_stubs import pod_sets_api_stub
from library.python.par_apply import par_apply

from .. import utils

# YP clusters known to this tool. Used as the default cluster list when none is
# given and as the set of allowed values for the --datacenter option.
CLUSTERS = ["sas", "man", "vla", "myt", "iva"]
# RPC endpoint passed to the nanny client built in get_nanny_client().
# NOTE(review): this is plain http while an OAuth token is handed to the same
# client - confirm whether an https endpoint should be used instead.
YP_LITE_API = "http://yp-lite-ui.nanny.yandex-team.ru/api/yplite/pod-sets"


class YpNode(object):
    """A YP node of one cluster together with the resources collected for it.

    Resources are attached with :meth:`add`. The pod sets derived from them
    are cached; the cache is dropped every time a resource is added, so
    resources should be attached through :meth:`add` rather than by touching
    ``resources`` directly.

    Two nodes are equal when both ``node_id`` and ``cluster`` are equal.
    """

    @classmethod
    def selectors(cls):
        """Return the YP attribute paths selected when listing nodes."""
        return ["/meta/id"]

    def __init__(self, node_id, cluster):
        self.cluster = cluster
        self.node_id = node_id
        self.resources = []
        self.disk_reduce = 300 << 30  # Reduce all disks by this value
        self.rootfs_space = 100 << 30  # 100Gb
        self.rootfs_bandwidth = 80 << 20
        self.vol_limit = 10 << 40  # 10Tb
        # Lazily computed caches; None means "not computed yet".
        self._pods = None
        self._scheduled_pods = None

    def __repr__(self):
        return "{} (empty: {}, cluster: {})".format(self.node_id, self.empty, self.cluster)

    def _identity(self):
        """Return the tuple that identifies this node."""
        # A tuple keeps ("ab", "c") and ("a", "bc") distinct, unlike string
        # concatenation of the two parts.
        return (self.node_id, self.cluster)

    def __hash__(self):
        return hash(self._identity())

    def __eq__(self, other):
        if not isinstance(other, YpNode):
            return NotImplemented
        return self._identity() == other._identity()

    def __ne__(self, other):
        # Needed explicitly on Python 2, where __ne__ is not derived from __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def pods(self):
        """Return the set of pods actually allocated on any resource of the node."""
        if self._pods is None:
            self._pods = {
                pod
                for res in self.resources
                for pod in res.actual_allocations
            }
        return self._pods

    def scheduled_pods(self):
        """Return the set of pods that are scheduled but not yet allocated."""
        if self._scheduled_pods is None:
            scheduled = {
                pod
                for res in self.resources
                for pod in res.scheduled_allocations
            }
            scheduled.difference_update(self.pods())  # Store only Unscheduled pods
            self._scheduled_pods = scheduled
        return self._scheduled_pods

    def add(self, other):
        """Attach a resource to the node.

        Raises:
            ValueError: if ``other`` is not a ``Resource`` instance.
        """
        if not isinstance(other, Resource):
            raise ValueError(
                "Unsupported operation, got {} and {}".format(
                    type(self),
                    type(other)
                )
            )
        self.resources.append(other)
        # The pod sets are derived from the resources, so they are stale now.
        self._pods = None
        self._scheduled_pods = None

    @property
    def empty(self):
        # Empty node - node with both empty allocated and scheduled pods
        return not self.pods() and not self.scheduled_pods()

    def print(self):
        """Print the node, its resources and its allocated/scheduled pods."""
        parts = ["Node: {}\n".format(self.__repr__())]
        parts.extend("\t" + str(res) + "\n" for res in self.resources)
        parts.extend("\t" + str(pod) + "\n" for pod in self.pods())
        parts.extend("\tScheduled " + str(pod) + "\n" for pod in self.scheduled_pods())
        print("".join(parts))

    def _free_of(self, res_name, label):
        """Return the free amount of the first resource named ``res_name`` as int.

        Raises:
            ValueError: if the node has no such resource.
        """
        for res in self.resources:
            if res.res_name == res_name:
                return int(res.free)
        raise ValueError("{} resource not found on node {}".format(label, self.node_id))

    @property
    def free_cpu(self):
        """Free amount of the node's "cpu" resource."""
        return self._free_of("cpu", "CPU")

    @property
    def free_mem(self):
        """Free amount of the node's "memory" resource."""
        return self._free_of("memory", "MEMORY")

    @property
    def fastest_storage(self):
        """Return "ssd" if the node has an ssd disk resource, otherwise "hdd"."""
        for res in self.resources:
            if res.res_name == "disk" and res.res_type == "ssd":
                return res.res_type
        return "hdd"

    def gen_volumes_list(self, fastest_mount_point):
        """Yield one ``VolumeRequest`` per persistent volume to create on the node.

        The first disk of the fastest storage class is mounted at
        ``fastest_mount_point`` and leaves room for the root fs on it. Every
        other disk is mounted at ``/storage/<n>``; a disk at least twice as
        large as ``vol_limit`` is split into ``free // vol_limit`` volumes.
        All quotas are reduced by ``disk_reduce`` and converted with ``>> 20``.
        """
        vol_id = 0
        fastest_used = False
        fastest = self.fastest_storage
        for disk in [res for res in self.resources if res.res_name == "disk"]:
            if disk.res_type == fastest and not fastest_used:
                fastest_used = True
                vr = pod_sets_api_pb2.VolumeRequest(
                    disk_quota_megabytes=(min(
                        disk.free - self.rootfs_space,
                        self.vol_limit
                    ) - self.disk_reduce) >> 20,
                    storage_class=disk.res_type,
                    bandwidth_limit_megabytes_per_sec=disk.bandwidth_free >> 20,
                    mount_point=fastest_mount_point,
                    bandwidth_guarantee_megabytes_per_sec=(disk.bandwidth_free - self.rootfs_bandwidth) >> 20
                )
                yield vr
            else:
                subvols = disk.free // self.vol_limit
                if subvols <= 1:
                    # The disk fits into a single volume.
                    vr = pod_sets_api_pb2.VolumeRequest(
                        disk_quota_megabytes=(min(
                            disk.free,
                            self.vol_limit
                        ) - self.disk_reduce) >> 20,
                        storage_class=disk.res_type,
                        bandwidth_limit_megabytes_per_sec=disk.bandwidth_free >> 20,
                        mount_point="/storage/{}".format(vol_id),
                        bandwidth_guarantee_megabytes_per_sec=disk.bandwidth_free >> 20
                    )
                    vol_id += 1
                    yield vr
                    continue
                # Large disk: split into equal volumes sharing the bandwidth guarantee.
                for _ in range(0, subvols):
                    vr = pod_sets_api_pb2.VolumeRequest(
                        disk_quota_megabytes=(self.vol_limit - self.disk_reduce) >> 20,
                        storage_class=disk.res_type,
                        bandwidth_limit_megabytes_per_sec=disk.bandwidth_free >> 20,
                        mount_point="/storage/{}".format(vol_id),
                        bandwidth_guarantee_megabytes_per_sec=(disk.bandwidth_free // subvols) >> 20
                    )
                    vol_id += 1
                    yield vr

    def create_pod_request(self, pod_set, network, fastest_mount_point, use_int_vcpu, enable_kvm):
        """Build a ``CreatePodsRequest`` for one pod taking the node's free resources.

        Args:
            pod_set: service id the pod is created in.
            network: network macro for the pod.
            fastest_mount_point: mount point of the volume on the fastest disk.
            use_int_vcpu: round the cpu guarantee down to a multiple of 1000.
            enable_kvm: request the kvm host device for the pod.

        Returns:
            The request, or None when the node is not empty.
        """
        if not self.empty:
            return None
        pod_allocation_request = pod_sets_api_pb2.AllocationRequest(
            pod_naming_mode=pod_sets_api_pb2.AllocationRequest.PodNamingMode.ENUMERATE,
            replicas=1,
            vcpu_guarantee=int(self.free_cpu / 1000) * 1000 if use_int_vcpu else self.free_cpu,
            network_macro=network,
            snapshots_count=3,
            root_fs_quota_megabytes=self.rootfs_space >> 20,
            root_volume_storage_class=self.fastest_storage,
            root_bandwidth_guarantee_megabytes_per_sec=self.rootfs_bandwidth >> 20,
            root_bandwidth_limit_megabytes_per_sec=(self.rootfs_bandwidth >> 20) * 3,
            thread_limit=0,
            memory_guarantee_megabytes=self.free_mem >> 20,
            work_dir_quota_megabytes=1024,
            persistent_volumes=list(self.gen_volumes_list(fastest_mount_point)),
            labels=[pod_sets_api_pb2.KeyValuePair(key="is_alive", value="true")],
            host_devices=pod_sets_api_pb2.HostDevices(
                dev_kvm="Enable",
            ) if enable_kvm else None,
        )
        create_pod_request = pod_sets_api_pb2.CreatePodsRequest(
            service_id=pod_set,
            cluster=self.cluster.upper(),
            allocation_request=pod_allocation_request,
        )

        def volume_request_repr(volume_request):
            return "{} MB, {} MB/s, at {}".format(
                volume_request.disk_quota_megabytes,
                volume_request.bandwidth_limit_megabytes_per_sec,
                volume_request.mount_point
            )

        print("Generate allocation request. CPU:{} MEM:{} Rootfs:{} Volumes:{}".format(
            pod_allocation_request.vcpu_guarantee,
            pod_allocation_request.memory_guarantee_megabytes,
            pod_allocation_request.root_fs_quota_megabytes,
            [volume_request_repr(vr) for vr in pod_allocation_request.persistent_volumes]
        ))
        return create_pod_request


@six.add_metaclass(abc.ABCMeta)
class Resource(object):
    """Base class for one YP resource record of a node.

    The class-level attributes listed by :meth:`fields` hold the YP attribute
    paths to select. On an instance the same attribute names are overwritten
    with the values returned for those paths, in the same order.
    """

    # Key of the primary amount inside the per-kind "free"/"used" status
    # sections; every concrete subclass sets it.
    _main_res_attribute = None
    # Original (misspelled) name of the attribute above; no code in this
    # module reads it, it is kept only for backward compatibility.
    _main_resource_attribute = None
    node_id = "/meta/node_id"
    res_name = "/meta/kind"
    _r_free = "/status/free"
    _r_used = "/status/used"
    _storage_class = "/spec/disk/storage_class"
    _r_pods = "/status"

    @staticmethod
    def fields():
        """Return the attribute names filled from a selected row, in row order."""
        return [
            "node_id",
            "res_name",
            "_r_free",
            "_r_used",
            "_storage_class",
            "_r_pods"
        ]

    @classmethod
    def selectors(cls):
        """Return the YP attribute paths to select, in the order of :meth:`fields`."""
        return [getattr(cls, field) for field in cls.fields()]

    @classmethod
    def create(cls, data):
        """Build the concrete resource object for a selected row.

        The subclass is chosen by the row's resource kind.

        Raises:
            TypeError: if the kind is unknown or the row cannot be processed;
                the offending row is printed before re-raising.
        """
        res_map = {
            "cpu": Cpu,
            "disk": Disk,
            "slot": Slot,
            "memory": Memory,
            "network": Network,
            "gpu": Gpu
        }
        try:
            return res_map.get(data[cls.fields().index("res_name")])(data)
        except TypeError:
            print("Can't create resource from data: ", str(data))
            raise

    def __init__(self, data):
        for field, value in zip(self.fields(), data):
            setattr(self, field, value)
        self._res_data = data
        # Only rows with a string storage class (disks) get it as their type;
        # everything else is typed by its resource kind.
        self.res_type = (
            self._storage_class
            if isinstance(self._storage_class, six.string_types) else
            self.res_name
        )
        self.free = self._free()
        self.used = self._used()
        self.total = self._total()
        self.actual_allocations = list(self._actual_allocations())
        self.scheduled_allocations = list(self._scheduled_allocations())

    def __repr__(self):
        return "Resource {} ({}) Free/Used/Total: {}/{}/{}".format(
            self.res_type,
            self._main_res_attribute,
            self.free,
            self.used,
            self.total
        )

    def _free(self):
        return self._r_free.get(self.res_name, {}).get(self._main_res_attribute)

    def _used(self):
        return self._r_used.get(self.res_name, {}).get(self._main_res_attribute)

    def _total(self):
        return self.free + self.used

    def _allocations(self, key):
        """Yield a ``Pod`` for every allocation stored under ``key`` in the status."""
        # A missing (or null) list is treated as "no allocations".
        for allocation in self._r_pods.get(key) or ():
            yield Pod(
                allocation.get("pod_id"),
                allocation.get("pod_uuid"),
                self.node_id
            )

    def _actual_allocations(self):
        return self._allocations("actual_allocations")

    def _scheduled_allocations(self):
        return self._allocations("scheduled_allocations")


class Cpu(Resource):
    """Resource of kind "cpu"; its amounts are read from ``guaranteed_capacity``."""
    _main_res_attribute = "guaranteed_capacity"


class Gpu(Resource):
    """Resource of kind "gpu"; its amounts are read from ``guaranteed_capacity``."""
    _main_res_attribute = "guaranteed_capacity"


class Disk(Resource):
    """Resource of kind "disk": tracks bandwidth in addition to capacity."""

    _main_res_attribute = "capacity"  # bytes

    def __init__(self, data):
        # Must be set before the base initialiser runs, as in the original order.
        self._additional_res_attribute = "bandwidth"
        super(Disk, self).__init__(data)

        def bandwidth_in(section):
            # Look up the bandwidth value of this resource kind in a status section.
            return section.get(self.res_name, {}).get(self._additional_res_attribute)

        self.bandwidth_free = bandwidth_in(self._r_free)
        self.bandwidth_used = bandwidth_in(self._r_used)
        self.bandwidth_total = self.bandwidth_used + self.bandwidth_free

    def __repr__(self):
        bandwidth_part = " ({}) {}/{}/{}".format(
            self._additional_res_attribute,
            self.bandwidth_free,
            self.bandwidth_used,
            self.bandwidth_total
        )
        return super(Disk, self).__repr__() + bandwidth_part


class Memory(Resource):
    """Resource of kind "memory"; its amounts are read from ``guaranteed_capacity``."""
    _main_res_attribute = "guaranteed_capacity"  # bytes


class Network(Resource):
    """Resource of kind "network"; its amounts are read from ``guaranteed_bandwidth``."""
    _main_res_attribute = "guaranteed_bandwidth"


class Slot(Resource):
    """Resource of kind "slot"; its amounts are read from ``guaranteed_capacity``."""
    _main_res_attribute = "guaranteed_capacity"


class Pod(object):
    """A pod allocation seen on a node resource.

    Pods are identified by their uuid: two ``Pod`` objects with the same uuid
    are equal regardless of id or node, so the same pod reported by several
    resources collapses to one element in a set.
    """

    def __init__(self, pod_id, pod_uuid, node):
        self.id = pod_id
        self.uuid = pod_uuid
        self.node = node

    def __repr__(self):
        return "Pod {}".format(self.id)

    def __hash__(self):
        return hash(self.uuid)

    def __eq__(self, other):
        # Compare the uuids themselves, not their hashes: a hash comparison
        # makes a pod equal to its own uuid string and to any colliding value.
        if not isinstance(other, Pod):
            return NotImplemented
        return self.uuid == other.uuid

    def __ne__(self, other):
        # Needed explicitly on Python 2, where __ne__ is not derived from __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result


def get_segment_nodes(yp_segment, walle_project=None, clusters=None, add_dead=False):
    """Return the set of node ids of a YP segment.

    Args:
        yp_segment: value of the node's ``/labels/segment``.
        walle_project: optional value of ``/labels/extras/walle/project`` to
            filter by.
        clusters: clusters to query; all of ``CLUSTERS`` when empty or None.
        add_dead: when false, only nodes whose hfsm state is "up" are returned.

    Raises:
        yp_c.YpClientError: re-raised after printing when a YP request fails.
    """
    req_filter = ['[/labels/segment]="{}"'.format(yp_segment)]
    if walle_project:
        req_filter.append('[/labels/extras/walle/project] = "{}"'.format(walle_project))
    if not add_dead:
        req_filter.append('[/status/hfsm/state] = "up"')
    query = " AND ".join(req_filter)
    nodes = set()
    for dc in clusters or CLUSTERS:
        try:
            with yp_c.YpClient(dc, config={"token": utils.oauth_token()}) as c:
                output = c.select_objects(
                    "node",
                    selectors=YpNode.selectors(),
                    filter=query,
                    batching_options=yp_c.BatchingOptions()
                )
            # An empty or missing result contributes nothing (previously a
            # None result was added to the list and raised TypeError).
            if output:
                # Each row holds the selected values; the only selector is the id.
                nodes.update(row[0] for row in output if row)
        except yp_c.YpClientError as e:
            print("Error with YP communicating: {}".format(e))
            raise
    return nodes


def get_yp_resources(cluster, host):
    """Select the resource rows of node ``host`` from YP cluster ``cluster``.

    The selected attributes are those of ``Resource.selectors()``. A
    ``yp_c.YpClientError`` is printed and re-raised.
    """
    print("Getting YP resources for cluster {}, host {}".format(cluster, host))
    node_filter = '[/meta/node_id] = "{}"'.format(host)
    try:
        with yp_c.YpClient(cluster, config={"token": utils.oauth_token()}) as client:
            return client.select_objects(
                "resource",
                selectors=Resource.selectors(),
                batching_options=yp_c.BatchingOptions(),
                filter=node_filter
            )
    except yp_c.YpClientError as err:
        print("Error with YP communicating: {}".format(err))
        raise


def get_resources_in_segment(segment, clusters, walle_project=None, add_dead=None):
    """Collect the resources of every node of a segment, grouped by node.

    For each cluster the segment's nodes are listed, their resources are
    fetched through ``par_apply`` and attached to one ``YpNode`` per node id.

    Args:
        segment: YP segment name.
        clusters: iterable of cluster names to query.
        walle_project: optional Wall-e project to filter nodes by.
        add_dead: passed through to ``get_segment_nodes``.

    Returns:
        dict mapping node id to its ``YpNode``.
    """
    segment_resources = {}
    for cluster in clusters:
        segment_nodes = get_segment_nodes(segment, walle_project=walle_project, clusters=[cluster], add_dead=add_dead)
        # Bind the cluster now instead of closing over the loop variable.
        fetch_node_resources = functools.partial(get_yp_resources, cluster)
        yp_resource_objects = par_apply(segment_nodes, fetch_node_resources, 5)
        # Walk the per-node results directly; summing the lists is quadratic.
        for node_resources in yp_resource_objects:
            for yp_resource_object in node_resources or ():
                resource = Resource.create(yp_resource_object)
                node = segment_resources.get(resource.node_id)
                if node is None:
                    node = YpNode(resource.node_id, cluster)
                    segment_resources[resource.node_id] = node
                node.add(resource)
    return segment_resources


def list_yp_nodes(params):
    """Print the nodes of ``params.segment``; only empty ones when ``params.empty`` is set."""
    scope = "free/empty" if params.empty else "all"
    print("Getting list of {} nodes in segment '{}'".format(scope, params.segment))
    nodes_by_id = get_resources_in_segment(
        params.segment,
        params.datacenter,
        walle_project=params.walle_project,
        add_dead=params.with_dead
    )
    for node in nodes_by_id.values():
        # In "empty only" mode skip every node that has pods.
        if params.empty and not node.empty:
            continue
        node.print()


def add_remove_label(params, action=None):
    """Set the "is_alive" label of the given pods through the yp-lite api.

    The label is set to true or false, to add the pod to samogon or remove it.
    When ``action`` is not given it is derived from ``params.action``
    ("add" -> true, anything else -> false). ``params.pods`` holds
    ``dc:pod_id`` strings.
    """
    pod_set = params.pod_set.replace("-", "_")
    if not action:
        action = b"true" if params.action == "add" else b"false"
    stub = get_nanny_client()
    for pod in params.pods:
        dc, pod_id = pod.split(":")
        cluster = dc.upper()
        current = stub.get_pod(pod_sets_api_pb2.GetPodRequest(
            pod_id=pod_id,
            cluster=cluster
        ))
        current_labels = current.pod.labels.attributes
        labels = set()
        version = None
        try:
            for label in current_labels:
                key = label.key
                if key == "is_alive":
                    label.value = action
                elif key == "nanny_version":
                    version = label.value
                labels.add((key, label.value))
        except AttributeError:
            labels.add(("is_alive", action))
            labels.add(("nanny_service_id", pod_set))

        # Adding to a set is a no-op when the pair is already present.
        labels.add(("is_alive", action))

        stub_request = pod_sets_api_pb2.UpdatePodRequest(
            labels=labels,
            pod_id=pod_id,
            version=version,
            cluster=cluster
        )
        print("Set is_alive label to {} for {} in {}".format(action, pod_id, dc))
        stub.update_pod(stub_request)
        if params.action != "remove":
            continue
        print(
            "To remove pod:\n",
            "Wait until samogon will generate new snapshot",
            "here: https://nanny.yandex-team.ru/ui/#/services/catalog/{}\n".format(pod_set),
            "After that you should remove old snapshot, and remove pod {} here:".format(pod_id),
            "https://nanny.yandex-team.ru/ui/#/services/catalog/{}/yp_pods/".format(pod_set)
        )


def get_nanny_client():
    """Return a YP-lite pod-sets service stub on top of a ``nanny_c.RetryingRpcClient``."""
    rpc_client = nanny_c.RetryingRpcClient(
        rpc_url=YP_LITE_API,
        oauth_token=utils.oauth_token(),
        request_timeout=30,
    )
    service_stub = pod_sets_api_stub.YpLiteUIPodSetsServiceStub(rpc_client)
    return service_stub


def allocate_pods(params):
    """Allocate one pod per free node of a segment and wait until all are assigned.

    ``params.count`` is either an integer (or numeric string) or "all"; it is
    replaced on ``params`` by the resolved integer. Nothing is allocated when
    there are no free nodes, fewer free nodes than requested, or the count is
    negative.

    After the requests are sent, the pods are polled every 10 seconds until
    each of them reports the "SS_ASSIGNED" scheduling state. There is no
    timeout on this wait.

    Raises:
        ValueError: if ``params.count`` is neither "all" nor an integer.
    """
    print("Get all YP resources and node list in segment {}".format(params.segment))
    segment_nodes = get_resources_in_segment(
        params.segment,
        params.datacenter,
        walle_project=params.walle_project
    )
    free_nodes = [node for node in segment_nodes.values() if node.empty]
    if not free_nodes:
        print("No free nodes to allocate in segment {}".format(params.segment))
        return

    free_nodes_count = len(free_nodes)

    if params.count == "all":
        print("Will allocate {} nodes".format(free_nodes_count))
        params.count = free_nodes_count
    else:
        params.count = int(params.count)

    if params.count < 0:
        # A negative count would turn the slice below into "all but the last N".
        print("Can't allocate a negative number of nodes: {}".format(params.count))
        return

    if free_nodes_count < params.count:
        print("Can't serve request for allocation {} nodes. Free nodes in segment {}: {}".format(
            params.count, params.segment, free_nodes_count
        ))
        return
    stub = get_nanny_client()
    pod_set = params.pod_set.replace("-", "_")
    new_pods = []
    print("Generate and send allocation requests to nanny yp-lite api")
    for node in free_nodes[:params.count]:
        pod_request = node.create_pod_request(
            pod_set, params.network, params.mount_point, params.use_int_vcpu, params.enable_kvm
        )
        response = stub.create_pods(pod_request)
        if len(response.pod_ids) != 1:
            print(
                "Can't create pod for service {} on cluster {}. Response:\n {}".format(
                    pod_set, node.cluster,
                    proto_json_format.MessageToJson(response)
                )
            )
        # Track whatever was actually created; an empty response adds nothing
        # instead of failing on pod_ids[0].
        for pod_id in response.pod_ids:
            new_pods.append((node.cluster.upper(), pod_id))
            print("Allocation request accepted for pod {} on cluster {}".format(pod_id, node.cluster))

    print("Check for request status")
    while new_pods:
        print("{} pod request(s) not assigned.".format(len(new_pods)))
        # Build the next round's list instead of removing from the list being
        # iterated, which skips the element after every removal.
        still_pending = []
        for cluster, pod in new_pods:
            get_pod_request = pod_sets_api_pb2.GetPodRequest(cluster=cluster, pod_id=pod)
            response = stub.get_pod(get_pod_request)
            pod_scheduling = proto_json_format.MessageToDict(response.pod.status.scheduling)
            if pod_scheduling.get("state") == "SS_ASSIGNED":
                print("Pod {} assigned".format(pod))
            else:
                still_pending.append((cluster, pod))
        new_pods = still_pending
        if new_pods:
            time.sleep(10)
    return


def setup_parser(parser):
    """Register the "list", "allocate" and "samogon" sub-commands on ``parser``.

    Each sub-command stores its handler in the ``func`` default.
    """
    dc_validate = functools.partial(utils.list_parse, allowed=set(CLUSTERS))
    subparsers = parser.add_subparsers(help="sub-command help")

    def add_node_selection_args(target):
        # Options shared by the sub-commands that select nodes of a segment.
        target.add_argument("-s", "--segment", default="sandbox", help="YP segment with target Nodes")
        target.add_argument("-w", "--walle-project", default=None, help="Filter Nodes by Wall-e project")
        target.add_argument(
            "-d",
            "--datacenter",
            default=CLUSTERS,
            type=dc_validate,
            help="Datacenter(s) to deal with. Valid = {}".format(", ".join(CLUSTERS))
        )

    # list
    lister = subparsers.add_parser("list", help="List nodes with resources in YP-segment")
    lister.add_argument("-a", "--with-dead", action="store_true", help="Add dead host")
    lister.add_argument("-e", "--empty", action="store_true", help="List only free hosts in segment")
    add_node_selection_args(lister)
    lister.set_defaults(func=list_yp_nodes)

    # allocate
    allocator = subparsers.add_parser("allocate", help="Allocate pods in pod_set")
    allocator.add_argument("-p", "--pod-set", default="bootstrap_sandbox3", help="PodSet for pods creation")
    allocator.add_argument("-n", "--network", default="_CMSEARCHNETS_", help="Network for created pods")
    allocator.add_argument(
        "-c",
        "--count",
        help="How many pods should i allocate. int or \"all\"",
        required=True
    )
    add_node_selection_args(allocator)
    allocator.add_argument(
        "-m",
        "--mount-point",
        default="/place",
        help="Path to mount point of the volume located on the fastes disk"
    )
    allocator.add_argument(
        "-i",
        "--use-int-vcpu",
        default=False,
        action="store_true",
        help="Allocate pods with integer vcpu number"
    )
    allocator.add_argument(
        "--enable-kvm",
        default=False,
        action="store_true",
        help="Enable kvm device for allocated pods"
    )
    allocator.set_defaults(func=allocate_pods)

    # samogon
    labeler = subparsers.add_parser(
        "samogon",
        help="Add or remove pod from samogon by setting 'is_alive' pod label via yp-lite api"
    )
    labeler.add_argument(
        "--action",
        choices=["add", "remove"],
        required=True,
        help="Add (is_alive=true) or remove (is_alive=false) pod from samogon."
    )
    labeler.add_argument("-p", "--pod-set", default="bootstrap_sandbox3", help="PodSet containing pods")
    labeler.add_argument(
        "--pods",
        nargs="+",
        required=True,
        help="Space separated list of dc:pod_id to manage. Example sas:bootstrap-sandbox3-1"
    )
    labeler.set_defaults(func=add_remove_label)
