# coding: utf-8

from __future__ import absolute_import, print_function

import copy
import time
import functools
import logging

from deepdiff import DeepDiff
import yt.yson as yson

import infra.rtc.iolimit_ticketer.yp_model as yp_model

MEGABYTE = 1024 * 1024


def _update_disk_spec(yp_client, pod_id, modify_spec):
    """Read-modify-write a pod's ``/spec/disk_volume_requests`` inside a YP transaction.

    Args:
        yp_client: YP client used for the transaction and object reads/writes.
        pod_id: id of the pod object to update.
        modify_spec: callable invoked as ``modify_spec(disk_spec=...)`` with the
            current disk volume requests; returns the new value to write, or a
            falsy value when nothing should be written.

    Returns:
        True if a new spec was written and the transaction committed, False if
        ``modify_spec`` returned a falsy value (the transaction is aborted).

    Raises:
        Whatever the client calls or ``modify_spec`` raise. On that path the
        transaction is aborted on a best-effort basis before the original
        exception propagates.
    """
    object_type = "pod"
    path = "/spec/disk_volume_requests"
    transaction_id = yp_client.start_transaction()
    finished = False
    try:
        resp = yp_client.get_object(
            object_type, pod_id, [path],
            options=dict(fetch_timestamps=True),
            enable_structured_response=True
        )
        disk_spec = resp["result"][0]["value"]
        timestamp = resp["result"][0]["timestamp"]
        disk_spec_new = modify_spec(disk_spec=disk_spec)
        if disk_spec_new:
            yp_client.update_object(
                object_type, pod_id,
                set_updates=[{"path": path, "value": disk_spec_new}],
                transaction_id=transaction_id,
                # Reject the write if the attribute changed since we read it.
                attribute_timestamp_prerequisites=[{"path": path, "timestamp": timestamp}]
            )
            yp_client.commit_transaction(transaction_id)
        finished = True
    finally:
        if not finished:
            # Error path: do not leave the transaction open. Swallow (but log) a
            # failing abort so the original exception is the one that propagates.
            try:
                yp_client.abort_transaction(transaction_id)
            except Exception:
                logging.exception("Failed to abort transaction %s for pod %s", transaction_id, pod_id)

    if disk_spec_new:
        return True
    # Nothing to write: release the transaction explicitly.
    yp_client.abort_transaction(transaction_id)
    return False


def update_pods(cluster_name, pods, sleep_time, dry_run=True, overwrite=False, use_limits=False, zero_limits=False):
    """Fill IO bandwidth guarantees (and optionally limits) into pods' disk volume requests.

    Each pod's ``/spec/disk_volume_requests`` is rewritten from the per-volume
    values carried by the pod object. Those values are multiplied by MEGABYTE and
    stored as YSON uint64 (presumably MB/s -> bytes/s; TODO confirm units).

    Args:
        cluster_name: key into ``yp_model.YP_ADDRESS_MAP`` selecting the YP endpoint.
        pods: iterable of pod descriptors exposing ``pod_id``, ``rack``,
            ``service_id``, ``deploy_engine`` and ``volumes``. Each volume exposes
            ``mount_path``, ``storage_class``, ``bandwidth_guarantee`` and
            ``bandwidth_limit``. The rack pause suggests they are expected to be
            grouped by rack (verify against caller).
        sleep_time: seconds to pause after a pod whose spec was written or whose
            update failed. Three times this is paused whenever the rack differs
            from the previous pod's rack.
        dry_run: only log the would-be diff; never write.
        overwrite: replace already-set guarantee (and limit, if ``use_limits``).
        use_limits: also fill ``bandwidth_limit``.
        zero_limits: together with ``use_limits``, force any non-zero
            ``bandwidth_limit`` to 0. It also marks every labelled volume with a
            quota policy as a candidate.
    """

    def is_relevant(volume_spec):
        # Only volumes with a mount_path label and a quota policy are candidates.
        labels = volume_spec.get("labels")
        if not labels or not labels.get("mount_path"):
            return False
        policy = volume_spec.get("quota_policy")
        if not policy:
            return False
        if overwrite or zero_limits:
            return True
        # Otherwise skip volumes that already carry both a guarantee and a limit.
        return not (policy.get("bandwidth_guarantee") and policy.get("bandwidth_limit"))

    def patch_volume(pod, volume_spec, volumes_by_path):
        # Returns volume_spec itself when it is left alone, otherwise a patched deep copy.
        if not is_relevant(volume_spec):
            return volume_spec

        mount_path = volume_spec["labels"]["mount_path"]
        if mount_path not in volumes_by_path:
            logging.warning("Pod %s/%s from service %s has no volume %s", pod.deploy_engine, pod.pod_id, pod.service_id, mount_path)
            return volume_spec

        computed = volumes_by_path[mount_path]
        if volume_spec.get("storage_class") != computed.storage_class:
            logging.warning("Pod %s/%s from service %s has volume %s with different storage class", pod.deploy_engine, pod.pod_id, pod.service_id, mount_path)
            return volume_spec

        patched = copy.deepcopy(volume_spec)
        policy = patched["quota_policy"]
        if not policy.get("bandwidth_guarantee") or overwrite:
            policy["bandwidth_guarantee"] = yson.YsonUint64(computed.bandwidth_guarantee * MEGABYTE)
        if use_limits:
            if not policy.get("bandwidth_limit") or overwrite:
                policy["bandwidth_limit"] = yson.YsonUint64(computed.bandwidth_limit * MEGABYTE)
            # Applied after the fill above, so zero_limits wins over computed limits.
            if zero_limits and policy.get("bandwidth_limit"):
                policy["bandwidth_limit"] = yson.YsonUint64(0)
        return patched

    def modify_spec(pod, disk_spec):
        # Callback for _update_disk_spec: new spec to write, or None to skip.
        volumes_by_path = dict((volume.mount_path, volume) for volume in pod.volumes)
        new_disk_spec = [patch_volume(pod, volume_spec, volumes_by_path) for volume_spec in disk_spec]
        assert len(disk_spec) == len(new_disk_spec)

        if new_disk_spec == disk_spec:
            logging.info("Pod %s/%s/%s from service %s remain unchanged", pod.deploy_engine, pod.pod_id, pod.rack, pod.service_id)
            return None

        # Per-mount_path structural diff, used only for logging.
        message = {}
        for before, after in zip(disk_spec, new_disk_spec):
            delta = DeepDiff(before, after, verbose_level=2)
            if not delta:
                continue
            assert "dictionary_item_added" in delta or "values_changed" in delta
            message[before["labels"]["mount_path"]] = delta
        assert message

        if dry_run:
            logging.info("Pod %s/%s/%s from service %s disk spec should be changed: %r", pod.deploy_engine, pod.pod_id, pod.rack, pod.service_id, message)
            return None
        logging.info("Pod %s/%s/%s from service %s disk spec changed: %r", pod.deploy_engine, pod.pod_id, pod.rack, pod.service_id, message)
        return new_disk_spec

    with yp_model.create_yp_client(yp_model.YP_ADDRESS_MAP[cluster_name]) as yp_client:
        previous_rack = None
        pause_needed = False
        for pod in pods:
            # Pace writes: pause after each pod that was written or failed...
            if pause_needed:
                logging.info("Waiting for next slot...")
                time.sleep(sleep_time)
            # ...and pause longer when moving on to a different rack.
            if previous_rack is not None and previous_rack != pod.rack:
                logging.info("Waiting before new rack...")
                time.sleep(sleep_time * 3)
            previous_rack = pod.rack
            try:
                pause_needed = _update_disk_spec(yp_client, pod.pod_id, functools.partial(modify_spec, pod=pod))
            except Exception as err:
                logging.exception("Pod %s/%s/%s from service %s update failed: %s", pod.deploy_engine, pod.pod_id, pod.rack, pod.service_id, err)
                pause_needed = True
