import uuid

import operator
import yp.data_model as data_model
from yp_proto.yp.client.api.proto.autogen_pb2 import TResource
from google.protobuf.json_format import MessageToDict
from sepelib.core import config

from yp.client import to_proto_enum
from infra.swatlib.gutils import idle_iter
from yt import yson
from collections import defaultdict


from infra.qyp.vmproxy.src import errors
from infra.qyp.vmproxy.src.lib import abc_roles
from infra.qyp.vmproxy.src.lib.yp import yputil
from infra.qyp.vmproxy.src.lib.yp import yp_client as vmproxy_yp_client
from infra.qyp.vmproxy.src.web import error_utils
from infra.qyp.proto_lib import vmset_pb2

# Base YP filter: only objects whose /labels/deploy_engine label equals "QYP".
# make_yp_query() ANDs further clauses onto this expression.
DEFAULT_YP_QUERY = '[/labels/deploy_engine] = "QYP"'
# Upper bound on the CPU capacity of a single scheduled allocation that
# forced_node_free() still treats as "tentacles only" (see QEMUKVM-1514).
# NOTE(review): presumably in YP CPU units (millicores) — confirm.
TENTACLES_CPU_GUARANTEE = 100


def quoted(items):
    """
    Lazily wrap every element of *items* in double quotes.

    Elements are formatted with ``str.format`` and are not escaped.

    :type items: collections.Iterable
    :rtype: collections.Iterator[str]
    """
    template = '"{}"'
    for value in items:
        yield template.format(value)


def make_yp_query(query=None, meta_ids=None):
    """
    Build a YP filter expression selecting QYP objects.

    Starts from DEFAULT_YP_QUERY and ANDs in one clause per populated field of
    *query* (id substring, account ids, segment ids, per-label filters) and,
    when given, a restriction of ``/meta/id`` to *meta_ids*. Values are
    interpolated as-is (no escaping).

    :type query: vmset_pb2.YpVmFindQuery
    :type meta_ids: [str]
    :rtype: str
    """
    clauses = [DEFAULT_YP_QUERY]

    if query:
        if query.name:
            clauses.append('is_substr("{}", [/meta/id])'.format(query.name))
        if query.account:
            clauses.append('[/spec/account_id] in ({})'.format(','.join(quoted(query.account))))
        if query.segment:
            clauses.append('[/spec/node_segment_id] in ({})'.format(','.join(quoted(query.segment))))
        if query.labels_filters:
            for label_name, label_filter in query.labels_filters.items():
                # Only the first populated condition of each label filter is applied.
                if label_filter.equal:
                    clause = '[/labels/{}] = "{}"'.format(label_name, label_filter.equal)
                elif label_filter.in_values:
                    clause = 'try_get_string([/labels/{}], "") in ({})'.format(
                        label_name, ','.join(quoted(label_filter.in_values)))
                elif label_filter.is_substr:
                    clause = 'is_substr("{}", [/labels/{}])'.format(label_filter.is_substr, label_name)
                elif label_filter.is_null:
                    clause = 'is_null([/labels/{}])'.format(label_name)
                elif label_filter.is_not_null:
                    clause = 'NOT is_null([/labels/{}])'.format(label_name)
                else:
                    continue
                clauses.append(clause)

    if meta_ids:
        clauses.append('[/meta/id] in ({})'.format(','.join(quoted(meta_ids))))

    return ' AND '.join(clauses)


def multikeysort(items, fields):
    """
    Sort *items* by several attributes, most significant first.

    Each entry of *fields* is an attribute path accepted by
    ``operator.attrgetter`` (dotted paths allowed), optionally prefixed with
    ``+`` (ascending, the default) or ``-`` (descending). The sort is stable:
    items that compare equal on every field keep their input order.

    :type items: collections.Iterable
    :type fields: collections.Iterable[str]
    :rtype: list
    """
    result = list(items)
    # Python's sort is stable, including with reverse=True. Sorting by the
    # least significant field first and the most significant field last
    # therefore gives the same order as one lexicographic multi-key comparison.
    # It also avoids ``cmp``/``sorted(cmp=...)``, which do not exist on Python 3.
    for field in reversed(list(fields)):
        descending = field.startswith('-')
        if descending or field.startswith('+'):
            field = field[1:]
        result.sort(key=operator.attrgetter(field), reverse=descending)
    return result


# Maps a YP object type to the protobuf class that PodController.get_object()
# decodes the fetched value into. Object types missing from this map are not
# supported by get_object() (the lookup raises KeyError).
TYPE_MODEL_MAP = {
    data_model.OT_POD: data_model.TPod,
    data_model.OT_POD_SET: data_model.TPodSet,
    data_model.OT_RESOURCE: TResource
}


class PodController(object):
    """
    Facade over a YP client for QYP pods, pod sets and related YP objects.

    Wraps ``yp_client`` calls, decodes YSON payloads into ``data_model``
    protobufs and delegates pod permission checks to ``sec_policy``.
    """

    def __init__(self, yp_cluster, yp_client, sec_policy):
        """
        :type yp_cluster: str
        :type yp_client: infra.qyp.vmproxy.src.lib.yp.yp_client.YpClient
        :type sec_policy: infra.qyp.vmproxy.src.security_policy.SecurityPolicy
        """
        self.yp_cluster = yp_cluster
        self.yp_client = yp_client
        self.sec_policy = sec_policy

    def create_pod_with_pod_set(self, pod, pod_set):
        """
        Create *pod_set* and *pod* through the YP client.

        :type pod: data_model.TPod
        :type pod_set: data_model.TPodSet
        """
        # Note the client takes the pod set first.
        self.yp_client.create_pod_with_pod_set(pod_set, pod)

    def delete_pod_with_pod_set(self, object_id):
        """
        Remove the pod set with id *object_id* via ``yp_client.remove_pod_set``.

        :type object_id: str
        """
        self.yp_client.remove_pod_set(object_id)

    def get_keys_by_logins(self, logins):
        """
        Fetch the ``/labels/staff/keys`` label of the YP users in *logins*.

        :type logins: collections.Iterable[unicode]
        :rtype: list
        :returns: one YSON-decoded keys value per user found, in YP result order
        """
        logins = list(logins)
        if not logins:
            # Nothing to look up; skip the round-trip to YP.
            return []
        yp_query = '[/meta/id] in ({})'.format(','.join(quoted(logins)))
        selectors = ['/labels/staff/keys']
        request = self.yp_client.select_objects(yp_query, data_model.OT_USER, selectors)
        return [yson.loads(keys.values[0]) for keys in request.results]

    def get_object(self, object_id, object_type, timestamp=None):
        """
        Fetch one YP object and decode it into the model class for its type.

        :type object_id: str
        :type object_type: int
        :type timestamp: int | NoneType
        :rtype: data_model.TPod | data_model.TPodSet | TResource
        :raises KeyError: if *object_type* is not in TYPE_MODEL_MAP
        """
        rsp = self.yp_client.get_object(object_id, object_type, timestamp)
        return yputil.loads_proto(rsp.result.values[0], TYPE_MODEL_MAP[object_type])

    def get_pod(self, pod_id, timestamp=None):
        """
        :type pod_id: str
        :type timestamp: int | NoneType
        :rtype: data_model.TPod
        """
        return self.get_object(pod_id, data_model.OT_POD, timestamp)

    def get_pod_set(self, pod_set_id, timestamp=None):
        """
        :type pod_set_id: str
        :type timestamp: int | NoneType
        :rtype: data_model.TPodSet
        """
        return self.get_object(pod_set_id, data_model.OT_POD_SET, timestamp)

    def get_resource(self, resource_id, timestamp=None):
        """
        :type resource_id: str
        :type timestamp: int | NoneType
        :rtype: data_model.TResource
        """
        return self.get_object(resource_id, data_model.OT_RESOURCE, timestamp)

    def get_active_pod(self, pod_id):
        """
        Return the pod only if it is scheduled, running and has IP addresses.

        :type pod_id: str
        :rtype: data_model.TPod
        :raises errors.WrongStateError: if the pod has a scheduling error, is not
            in the ``assigned`` scheduling state, its current ISS state is not
            ``ACTIVE``, or it has no IPv6 address allocations
        """
        pod = self.get_pod(pod_id)
        if pod.status.scheduling.error.code:
            error_dict = MessageToDict(
                message=pod.status.scheduling.error,
                including_default_value_fields=False,
                preserving_proto_field_name=True
            )
            raise errors.WrongStateError(error_utils.yp_error_dict_to_str(error_dict))
        if pod.status.scheduling.state != to_proto_enum(data_model.ESchedulingState, 'assigned').number:
            raise errors.WrongStateError('Yp pod has not been scheduled yet')
        current_state = self._get_current_state(pod)
        if current_state != 'ACTIVE':
            raise errors.WrongStateError('Agent has not been started yet. Current status: {}'.format(current_state))
        if not pod.status.ip6_address_allocations:
            raise errors.WrongStateError('Ip addresses for pod have not been allocated yet')
        return pod

    def get_allocated_node_ids(self, query=None):
        """
        Return ``/status/scheduling/node_id`` of every pod matching *query*.

        :type query: str
        :rtype: list[str]
        """
        selectors = ['/status/scheduling/node_id']
        request = self.yp_client.list_pods(query, selectors=selectors)
        return [yson.loads(value.values[0]) for value in request.results]

    def list_pods(self, query=None, sort=None, skip=0, limit=0):
        """
        List QYP VMs matching *query*, optionally sorted and paginated in memory.

        When ``query.login`` is set, only pod sets where that user has the
        ``ACA_GET_QYP_VM_STATUS`` permission are considered.

        :type query: vmset_pb2.YpVmFindQuery
        :type sort: vmset_pb2.YpVmFindSort
        :type skip: int
        :type limit: int
        :rtype: list[vmset_pb2.Vm]
        """
        # One timestamp for both reads, so pod sets and pods come from the same snapshot.
        ts = self.yp_client.generate_timestamp()
        user_access_pod_set_ids = []
        if query and query.login:
            user_access_pod_set_ids = self.yp_client.get_user_access_allowed_to(
                user=query.login,
                object_type=data_model.OT_POD_SET,
                permission=data_model.ACA_GET_QYP_VM_STATUS
            )
            if not user_access_pod_set_ids:
                return []

        pod_set_items = self._get_pod_set_items(
            query=make_yp_query(query, meta_ids=user_access_pod_set_ids),
            timestamp=ts
        )
        if not pod_set_items:
            return []

        pod_set_ids = [ps.meta.id for ps in pod_set_items]
        vms = self._build_vms(pod_set_items, self._get_pod_dict(pod_set_ids, ts))
        if sort and sort.field:
            vms = multikeysort(vms, sort.field)

        # limit == 0 means "no limit".
        return vms[skip:skip + (limit or len(vms))] if (limit or skip) else vms

    def list_all_pods(self, query=None, skip=0, limit=0):
        """
        List QYP VMs matching *query*, paginating the pod-set query in YP.

        :type query: vmset_pb2.YpVmFindQuery
        :type skip: int
        :type limit: int
        :rtype: list[vmset_pb2.Vm]
        """
        ts = self.yp_client.generate_timestamp()

        pod_set_items = self._get_pod_set_items(query=make_yp_query(query=query), timestamp=ts, skip=skip, limit=limit)
        if not pod_set_items:
            return []

        pod_set_ids = [ps.meta.id for ps in pod_set_items]
        return self._build_vms(pod_set_items, self._get_pod_dict(pod_set_ids, ts))

    def _build_vms(self, pod_set_items, pods_dict):
        """
        Join pod sets with their pods and convert each pair into a VM.

        Pod sets with no entry in *pods_dict* are skipped.

        :type pod_set_items: list[data_model.TPodSet]
        :type pods_dict: dict[str, list]
        :rtype: list[vmset_pb2.Vm]
        """
        vms = []
        for pod_set_pb in idle_iter(pod_set_items):
            pod_yson = pods_dict.get(pod_set_pb.meta.id)
            if pod_yson is None:
                continue
            pod_pb = self._cast_yson_to_pod(pod_yson)
            vms.append(yputil.cast_pod_to_vm(pod_pb, pod_set_pb))
        return vms

    def _cast_yson_to_pod(self, obj):
        """
        Build a TPod from YSON values in the order produced by _get_pod_dict:
        meta, spec, status, labels, owners annotation, backup_list annotation.

        :rtype: data_model.TPod
        """
        pod = data_model.TPod(
            meta=yputil.loads_proto(obj[0], data_model.TPodMeta),
            spec=yputil.loads_proto(obj[1], data_model.TPodSpec),
            status=yputil.loads_proto(obj[2], data_model.TPodStatus),
        )
        yputil.cast_dict_to_attr_dict(yson.loads(obj[3]), pod.labels)
        annotations = {
            'owners': yson.loads(obj[4]),
            'backup_list': yson.loads(obj[5]),
        }
        yputil.cast_dict_to_attr_dict(annotations, pod.annotations)
        return pod

    def _cast_pod_yson_to_owners_dict(self, obj):
        """Decode the owners annotation (index 4) from a _get_pod_dict value list."""
        return yson.loads(obj[4])

    def _cast_yson_to_pod_set(self, obj):
        """
        Build a TPodSet from YSON values in the order meta, spec, labels.

        :rtype: data_model.TPodSet
        """
        pod_set = data_model.TPodSet(
            meta=yputil.loads_proto(obj[0], data_model.TPodSetMeta),
            spec=yputil.loads_proto(obj[1], data_model.TPodSetSpec)
        )
        yputil.cast_dict_to_attr_dict(yson.loads(obj[2]), pod_set.labels)
        return pod_set

    def _get_pod_set_items(self, query, timestamp, skip=0, limit=0):
        """
        :type query: str
        :type timestamp: int
        :type skip: int
        :type limit: int
        :rtype: list[data_model.TPodSet]
        """
        selectors = [
            '/meta',
            '/spec',
            '/labels',
        ]
        rsp = self.yp_client.list_pod_sets(query=query, timestamp=timestamp, selectors=selectors, offset=skip,
                                           limit=limit)

        return [self._cast_yson_to_pod_set(item.values) for item in idle_iter(rsp.results)]

    def _get_pod_dict(self, ids, ts):
        """
        Fetch pods by id and index their raw YSON values by pod id.

        Subresponses without values (pods that were not found) are skipped.
        Callers treat a missing key as "pod set without a pod".

        :type ids: list[str]
        :type ts: int
        :rtype: dict[str, list]
        :returns: pod id -> YSON values for meta, spec, status, labels,
            owners, backup_list and qyp_vm_spec (in that order)
        """
        selectors = [
            '/meta/id',
            '/meta',
            '/spec',
            '/status',
            '/labels',
            '/annotations/owners',
            '/annotations/backup_list',
            '/annotations/qyp_vm_spec',
        ]

        rsp = self.yp_client.get_objects(ids, data_model.OT_POD, selectors=selectors, timestamp=ts)
        return {
            yson.loads(item.result.values[0]): item.result.values[1:]
            for item in rsp.subresponses
            if item.result.values
        }

    def update_object(self, object_id, object_type, version, set_updates, t_id=None, ts=None):
        """
        Apply *set_updates* to an object after an optimistic version check.

        The object's ``version`` label must equal *version*. On success, the
        label is replaced with a fresh uuid4 together with the updates. When
        called without *t_id*, the check and the update are separate requests.

        :type object_id: str
        :type object_type: int
        :type version: str
        :type set_updates: dict[str, str]
        :type t_id: str | NoneType
        :type ts: int | NoneType
        :raises errors.ConcurrentModificationError: on version mismatch
        """
        if not set_updates:
            # Nothing to update
            return
        current = self.get_object(object_id, object_type, ts)
        cur_version = yputil.cast_attr_dict_to_dict(current.labels)['version']
        if version != cur_version:
            raise errors.ConcurrentModificationError('Concurrent modification detected, try again later')
        # Copy so the caller's dict is not modified by the version bump.
        updates = set_updates.copy()
        updates['/labels/version'] = yson.dumps(str(uuid.uuid4()))
        self.yp_client.update_object(object_id, object_type, updates, t_id)

    def update_pod(self, pod_id, version, set_updates, t_id=None, ts=None):
        """
        :type pod_id: str
        :type version: str
        :type set_updates: dict[str, str]
        :type t_id: str | NoneType
        :type ts: int | NoneType
        """
        self.update_object(pod_id, data_model.OT_POD, version, set_updates, t_id, ts)

    def update_pod_set(self, pod_set_id, version, set_updates, t_id=None, ts=None):
        """
        :type pod_set_id: str
        :type version: str
        :type set_updates: dict[str, str]
        :type t_id: str | NoneType
        :type ts: int | NoneType
        """
        self.update_object(pod_set_id, data_model.OT_POD_SET, version, set_updates, t_id, ts)

    @staticmethod
    def get_pod_container_ip(pod):
        """
        Return the first ``backbone`` IPv6 allocation without a persistent FQDN.

        :type pod: data_model.TPod
        :rtype: str
        :raises ValueError: if no such allocation exists
        """
        for addr in pod.status.ip6_address_allocations:
            if addr.vlan_id == 'backbone' and not addr.persistent_fqdn:
                return addr.address
        raise ValueError('Backbone address not found for pod')

    @staticmethod
    def get_pod_owners(pod):
        """
        Return the YSON-decoded ``owners`` annotation of *pod*.

        :type pod: data_model.TPod
        :rtype: dict
        :raises ValueError: if the annotation is absent
        """
        for item in pod.annotations.attributes:
            if item.key == 'owners':
                return yson.loads(item.value)
        raise ValueError('Owners not set for pod {}'.format(pod.meta.id))

    @staticmethod
    def _get_current_state(pod):
        """
        Return the agent-reported ISS state for the pod's current configuration.

        The configuration is identified by the groupStateFingerprint of the
        first ISS instance in the spec. Returns 'NOT_CREATED' when no reported
        state matches it.

        :type pod: data_model.TPod
        :rtype: str
        """
        conf_id = pod.spec.iss.instances[0].id.configuration.groupStateFingerprint
        states = pod.status.agent.iss.currentStates or []
        for state in states:
            if state.workloadId.configuration.groupStateFingerprint == conf_id:
                return state.currentState
        return 'NOT_CREATED'

    def check_read_permission(self, pod_id, subject_id):
        """
        :type pod_id: str
        :type subject_id: str
        :rtype: bool
        """
        return self.sec_policy.check_read_permission(pod_id, subject_id, self.yp_client)

    def check_write_permission(self, pod_id, subject_id):
        """
        :type pod_id: str
        :type subject_id: str
        :rtype: bool
        """
        return self.sec_policy.check_write_permission(pod_id, subject_id, self.yp_client)

    def check_use_macro_permission(self, macro_name, subject_id):
        """
        Check whether *subject_id* may use network project *macro_name*.

        Macros listed in the ``vmproxy.network_whitelist`` config are always allowed.

        :type macro_name: str
        :type subject_id: str
        :rtype: bool
        """
        if macro_name in config.get_value('vmproxy.network_whitelist', []):
            return True
        return self.yp_client.check_object_permissions(
            object_id=macro_name,
            object_type=data_model.OT_NETWORK_PROJECT,
            subject_id=subject_id,
            permission=[data_model.ACA_USE]
        )

    def check_use_account_permission(self, acc_id, subject_id, use_cache):
        """
        Check whether *subject_id* may use YP account *acc_id*.

        The robot login and the ``tmp`` account are always allowed. Otherwise
        both the YP ``use`` permission and the ABC role filter must pass.

        :type acc_id: str
        :type subject_id: str
        :type use_cache: bool
        :rtype: bool
        """
        if subject_id == yputil.ROBOT_LOGIN:
            return True
        if acc_id == 'tmp':
            return True
        has_permission = self.yp_client.check_object_permissions(
            object_id=acc_id,
            object_type=data_model.OT_ACCOUNT,
            subject_id=subject_id,
            permission=[data_model.ACA_USE]
        )
        if not has_permission:
            return False
        return acc_id in abc_roles.filter_accounts_by_roles([acc_id], subject_id, use_cache)

    def start_transaction(self):
        """
        :rtype: (str, int)
        :returns: transaction id and its start timestamp
        """
        t_id, ts = self.yp_client.start_transaction()
        return t_id, ts

    def commit_transaction(self, t_id):
        """
        :type t_id: str
        """
        self.yp_client.commit_transaction(t_id)

    @staticmethod
    def get_backup_list(pod):
        """
        Return the ``backup_list`` annotation of *pod*, or [] when it is unset or empty.

        :type pod: data_model.TPod
        :rtype: list[yson.yson_types.YsonString]
        """
        annotations = yputil.cast_attr_dict_to_dict(pod.annotations)
        return annotations.get('backup_list') or []

    def update_backup_list(self, pod, backup_list):
        """
        Overwrite the ``backup_list`` annotation using the pod's current version label.

        :type pod: data_model.TPod
        :type backup_list: list[yson.yson_types.YsonString]
        """
        set_updates = {
            '/annotations/backup_list': yson.dumps(backup_list)
        }
        version = yputil.cast_attr_dict_to_dict(pod.labels)['version']
        self.update_pod(pod.meta.id, version, set_updates)

    def list_accounts(self, query):
        """
        :type query: str
        :rtype: list[data_model.TAccount]
        """
        rsp = self.yp_client.list_accounts(query)
        return [yputil.loads_proto(result.values[0], data_model.TAccount) for result in rsp.results]

    def get_accounts(self, ids, selectors=None):
        """
        Fetch accounts by id; ids with no returned values are skipped.

        :type ids: list[str]
        :type selectors: list[str]
        :rtype: list[data_model.TAccount]
        """
        rsp = self.yp_client.get_accounts(ids, selectors=selectors)
        return [
            yputil.loads_proto(result.result.values[0], data_model.TAccount)
            for result in rsp.subresponses
            if result.result.values
        ]

    def get_nodes_spec(self, query):
        """
        :type query: str
        :rtype: dict[str, dict]
        :returns: node id -> YSON-decoded ``/spec`` of that node
        """
        nodes = {}
        selectors = [
            '/meta/id',
            '/spec'
        ]
        rsp = self.yp_client.list_nodes(query, selectors=selectors)
        for result in rsp.results:
            node_id, spec = [yson.loads(val) for val in result.values]
            nodes[node_id] = spec
        return nodes

    def list_resources_by_nodes(self, node_ids):
        """
        List YP resources that belong to any of *node_ids*.

        :type node_ids: list[str]
        :rtype: list
        :returns: raw result rows with values for ``/meta/node_id``,
            ``/spec`` and ``/status/free``; an empty list when *node_ids* is empty
        """
        if not node_ids:
            return []
        selectors = [
            '/meta/node_id',
            '/spec',
            '/status/free'
        ]
        query = '[/meta/node_id] in ({})'.format(','.join(quoted(node_ids)))
        rsp = self.yp_client.list_resources(query, selectors=selectors)
        return rsp.results

    def forced_node_free(self, node_id):
        """
        Decide whether a node's CPU resource is free enough to treat the node as free.

        Looks at the resource ``cpu-<node_id with dots replaced by dashes>``.
        Returns True when it has no scheduled allocations, or exactly one
        allocation whose CPU capacity is at most TENTACLES_CPU_GUARANTEE.

        :type node_id: str
        :rtype: bool
        """
        resource_id = '-'.join(['cpu', node_id.replace('.', '-')])
        resource = self.get_resource(resource_id)
        allocations = resource.status.scheduled_allocations
        # QEMUKVM-1514
        tentacles_only = len(allocations) == 1 and allocations[0].cpu.capacity <= TENTACLES_CPU_GUARANTEE
        return not allocations or tentacles_only

    def get_free_network_ids(self):
        """
        Return network module ids of internet addresses whose ``/status`` is empty.

        :rtype: frozenset[str]
        """
        rsp = self.yp_client.select_objects(None, data_model.OT_INTERNET_ADDRESS, selectors=['/spec/network_module_id', '/status'])
        network_ids = set()
        for result in rsp.results:
            network_id, status = [yson.loads(val) for val in result.values]
            if not status:
                network_ids.add(network_id)
        return frozenset(network_ids)

    def list_account_ids_by_login(self, login):
        """
        :type login: str
        :rtype: list[str]
        """
        return self.yp_client.get_user_access_allowed_to(login, data_model.OT_ACCOUNT, data_model.ACA_USE)

    @staticmethod
    def _accumulate_pod_usage(values, usage, vmagent_versions_stats):
        """
        Add one pod's resource requests to *usage* and count its vmagent version.

        :param values: YSON values for the get_pod_stats selectors, in order:
            vcpu_guarantee, memory_guarantee, disk_volume_requests, vmagent_version label
        :type usage: vmset_pb2.ResourceInfo
        :type vmagent_versions_stats: collections.defaultdict
        """
        usage.cpu += yson.loads(values[0])
        usage.mem += yson.loads(values[1])
        for volume in yson.loads(values[2]):
            # Volumes mounted at '/' are not counted.
            if volume.get('labels', {}).get('mount_path', '') != '/':
                storage_class = volume.get('storage_class')
                usage.disk_per_storage[storage_class] += volume['quota_policy']['capacity']
        vmagent_version = yson.loads(values[3])
        if isinstance(vmagent_version, (bytes, unicode)):
            vmagent_versions_stats[vmagent_version] += 1
        else:
            vmagent_versions_stats['N/A'] += 1

    def get_pod_stats(self, account=None, segment=None):
        """
        Aggregate resource requests and vmagent versions over QYP pods.

        With *account* or *segment*, only pods of pod sets matching those
        filters are counted. Otherwise all pods matching DEFAULT_YP_QUERY are counted.

        :type account: [str]
        :type segment: [str]
        :rtype: (int, vmset_pb2.ResourceInfo, dict)
        :returns: number of pods counted, summed usage, vmagent version -> pod count
        """
        selectors = [
            '/spec/resource_requests/vcpu_guarantee',
            '/spec/resource_requests/memory_guarantee',
            '/spec/disk_volume_requests',
            '/labels/vmagent_version',
        ]
        usage = vmset_pb2.ResourceInfo()
        if account or segment:
            query_pb = vmset_pb2.YpVmFindQuery(account=account, segment=segment)
            resp = self.yp_client.list_pod_sets(query=make_yp_query(query=query_pb), selectors=['/meta/id'])
            pod_set_ids = [yson.loads(item.values[0]) for item in resp.results]
            if not pod_set_ids:
                return 0, usage, {}
            rsp = self.yp_client.get_objects(pod_set_ids, data_model.OT_POD, selectors=selectors)
            # Pod sets without a pod yield empty values; they are skipped.
            rows = [item.result.values for item in rsp.subresponses if item.result.values]
        else:
            rsp = self.yp_client.list_pods(query=DEFAULT_YP_QUERY, selectors=selectors)
            rows = [item.values for item in rsp.results]

        vmagent_versions_stats = defaultdict(int)
        for values in idle_iter(rows):
            self._accumulate_pod_usage(values, usage, vmagent_versions_stats)
        return len(rows), usage, vmagent_versions_stats

    def update_pod_with_move(self, pod_id, pod_version, pod_set_version, pod, pod_set_updates):
        """
        Replace the pod with *pod* and update its pod set in one transaction.

        :type pod_id: str
        :type pod_version: str
        :type pod_set_version: str
        :type pod: data_model.TPod
        :type pod_set_updates: dict[str, str]
        :raises errors.ConcurrentModificationError: if the pod's version label
            does not equal *pod_version*
        """
        t_id, ts = self.start_transaction()
        current = self.get_object(pod_id, data_model.OT_POD, ts)
        cur_version = yputil.cast_attr_dict_to_dict(current.labels)['version']
        if pod_version != cur_version:
            # NOTE(review): the transaction started above is neither committed
            # nor aborted on this path; confirm YP expires it, or abort it here.
            raise errors.ConcurrentModificationError('Concurrent modification detected, try again later')
        self.yp_client.remove_object(pod_id, data_model.OT_POD, t_id)
        self.yp_client.create_pod(pod, t_id)
        # The pod set is addressed by pod_id: this class assumes pod and pod set share an id.
        self.update_pod_set(pod_id, pod_set_version, pod_set_updates, t_id, ts)
        self.commit_transaction(t_id)

    def update_pod_with_acknowledge_eviction(self, pod_id, pod_version, pod_updates, pod_set_version, pod_set_updates):
        """
        Apply *pod_updates* plus an eviction acknowledgement, and optionally pod set updates.

        With *pod_set_updates*, both updates run in a single transaction.
        The caller's *pod_updates* dict is not modified.

        :type pod_id: str
        :type pod_version: str
        :type pod_updates: dict[str, str]
        :type pod_set_version: str
        :type pod_set_updates: dict[str, str]
        """
        pod_updates = dict(pod_updates)
        pod_updates['/control/acknowledge_eviction'] = yson.dumps({'message': 'Acknowledged by evacuate'})
        if pod_set_updates:
            t_id, ts = self.start_transaction()
            self.update_pod_set(pod_id, pod_set_version, pod_set_updates, t_id, ts)
            self.update_pod(pod_id, pod_version, pod_updates, t_id, ts)
            self.commit_transaction(t_id)
        else:
            self.update_pod(pod_id, pod_version, pod_updates)

    def get_alive_pods(self, pod_ids):
        """
        Return the subset of *pod_ids* for which YP returned a value.

        :type pod_ids: list[str]
        :rtype: set[str]
        """
        rsp = self.yp_client.get_objects(pod_ids, data_model.OT_POD, selectors=['/meta/id'])
        return {yson.loads(sub.result.values[0]) for sub in rsp.subresponses if sub.result.values}

    def get_all_daemon_set_pod_set_ids(self):
        """
        Return the ``/meta/pod_set_id`` of every daemon set.

        :rtype: set[str]
        """
        rsp = self.yp_client.select_objects(None, data_model.OT_DAEMON_SET, selectors=['/meta/pod_set_id'])
        return {yson.loads(result.values[0]) for result in rsp.results}

    def get_service_existing_scopes(self, service_id):
        """
        Return ``/labels/abc/scope_slug`` of the YP groups
        ``abc:service-scope:<service_id>:1`` and ``...:8`` that exist.

        :type service_id: str
        :rtype: list[str]
        """
        yp_group_ids = ['abc:service-scope:{}:1'.format(service_id), 'abc:service-scope:{}:8'.format(service_id)]
        yp_rsp = self.yp_client.get_objects(yp_group_ids, data_model.OT_GROUP, selectors=['/labels/abc/scope_slug'])
        return [yson.loads(sub.result.values[0]) for sub in yp_rsp.subresponses if sub.result.values]


class PodControllerFactory(object):
    """
    Creates PodController instances, each with a fresh YP client.

    The default cluster given at construction is used whenever get_object()
    is called without an explicit (truthy) cluster.
    """

    def __init__(self, cluster, sec_policy):
        self.cluster, self.sec_policy = cluster, sec_policy

    def get_object(self, cluster=None):
        """
        :type cluster: str
        :rtype: PodController
        """
        target_cluster = cluster or self.cluster
        return PodController(
            yp_cluster=target_cluster,
            yp_client=vmproxy_yp_client.make_yp_client(target_cluster),
            sec_policy=self.sec_policy,
        )
