from __future__ import unicode_literals

from gevent import threadpool

import yp.client
import yp.common
import yt.yson as yson
import yt_yson_bindings
import yp.data_model as data_model
from yp_proto.yp.client.api.proto import object_service_pb2
from infra.swatlib.gevent import geventutil
from infra.mc_rsc.src.consts import DEFAULT_OBJECT_SELECTORS, DELEGATE_REMOVING_LABEL
from infra.mc_rsc.src.lib.loaders import make_attr_paths


class YpClient(object):
    """
    Synchronous wrapper around a YP object-service stub.

    Every RPC is executed through a gevent ``ThreadPool`` (``self.tp``) by
    calling a stub method with a single request message. Object payloads are
    decoded by ``loader``, which also defines the payload format used in
    requests.
    """

    # NOTE: ``get_objects`` returns this very dict instance for every object
    # when ``fetch_timestamps`` is False, so callers must never mutate it.
    EMPTY_TIMESTAMPS = {}
    DEFAULT_WATCH_COUNT_LIMIT = 5000

    def __init__(self, stub, loader, threadpool_max_size=10):
        """
        :param stub: object-service stub; its RPC methods are invoked as
            ``stub.Method(request)`` inside the thread pool.
        :param loader: provides ``PAYLOAD_FORMAT``, ``ID_SELECTORS``,
            ``ID_ATTR_PATHS`` and decodes payloads via ``load_object`` /
            ``load_object_attrs``.
        :param threadpool_max_size: maximum size of the RPC thread pool.
        """
        self.stub = stub
        self.tp = threadpool.ThreadPool(maxsize=threadpool_max_size)
        # Not read anywhere in this class.
        self._enum_values_index = {}
        self.loader = loader

    def _get_object_values(self, object_type, object_id,
                           selectors=None, timestamp=None, ignore_nonexistent=False):
        """
        Fetch selected attribute payloads of a single object.

        :type object_type: str
        :type object_id: str
        :type selectors: list[str] | None
        :param selectors: attribute paths; ``DEFAULT_OBJECT_SELECTORS`` when
            empty or None.
        :type timestamp: int | None
        :type ignore_nonexistent: bool
        :return: ``result.value_payloads`` of the response. Callers in this
            class treat an empty result as "object does not exist" when
            ``ignore_nonexistent`` is set.
        """
        req = object_service_pb2.TReqGetObject()
        req.object_type = object_type
        req.object_id = object_id
        if timestamp is not None:
            req.timestamp = timestamp
        selectors = selectors or DEFAULT_OBJECT_SELECTORS
        req.selector.paths.extend(selectors)
        if ignore_nonexistent:
            req.options.ignore_nonexistent = True
        req.format = self.loader.PAYLOAD_FORMAT
        return self.tp.apply(self.stub.GetObject, [req]).result.value_payloads

    def _select_objects_values(self, object_type, limit,
                               query=None, selectors=None,
                               timestamp=None, fetch_timestamps=False,
                               continuation_token=None):
        """
        Run one page of SelectObjects.

        :type object_type: str
        :type limit: int
        :type query: str | None
        :type selectors: list[str] | None
        :param selectors: attribute paths; ``['']`` (the whole object) when
            empty or None.
        :type timestamp: int | None
        :type fetch_timestamps: bool
        :param continuation_token: token returned by the previous page, if any.
        :return: tuple ``(values, continuation_token)``. ``values`` holds one
            ``value_payloads`` entry per selected object. ``continuation_token``
            is the token from the response, to be passed to the next page.
        """
        req = object_service_pb2.TReqSelectObjects()
        req.object_type = object_type
        req.limit.value = limit
        if timestamp is not None:
            req.timestamp = timestamp
        selectors = selectors or ['']
        req.selector.paths.extend(selectors)
        if query:
            req.filter.query = query
        req.options.fetch_timestamps = fetch_timestamps
        if continuation_token:
            req.options.continuation_token = continuation_token
        req.format = self.loader.PAYLOAD_FORMAT
        resp = self.tp.apply(self.stub.SelectObjects, [req])
        rv = [r.value_payloads for r in resp.results]
        return rv, resp.continuation_token

    def _object_exists(self, object_type, object_id, timestamp=None):
        """
        :type object_type: str
        :type object_id: str
        :type timestamp: int | None
        :rtype: bool
        """
        v = self._get_object_values(object_id=object_id,
                                    object_type=object_type,
                                    timestamp=timestamp,
                                    ignore_nonexistent=True)
        return bool(v)

    def _create_object(self, object_pb, object_type, transaction_id=None):
        """
        Create one object from its proto representation.

        :type object_pb: yp.data_model.T
        :type object_type: str
        :type transaction_id: str | None
        :return: tuple ``(object_id, fqid)`` taken from the response.
        """
        req = object_service_pb2.TReqCreateObject()

        if transaction_id:
            req.transaction_id = transaction_id

        req.object_type = object_type
        req.attributes = yt_yson_bindings.dumps_proto(object_pb)
        rsp = self.tp.apply(self.stub.CreateObject, [req])
        return rsp.object_id, rsp.fqid

    def _create_objects(self, objects_pb, object_type, transaction_id=None):
        """
        Create several objects of the same type in a single request.

        :type objects_pb: list[yp.data_model.T]
        :type object_type: str
        :type transaction_id: str | None
        :return: created object ids, in response order.
        :rtype: list[str]
        """
        req = object_service_pb2.TReqCreateObjects()

        if transaction_id:
            req.transaction_id = transaction_id

        for o in objects_pb:
            subreq = req.subrequests.add()
            subreq.object_type = object_type
            subreq.attributes = yt_yson_bindings.dumps_proto(o)

        rsp = self.tp.apply(self.stub.CreateObjects, [req])
        return [r.object_id for r in rsp.subresponses]

    def generate_timestamp(self):
        """
        :rtype: int
        """
        req = object_service_pb2.TReqGenerateTimestamp()
        resp = self.tp.apply(self.stub.GenerateTimestamp, [req])
        return resp.timestamp

    def start_transaction(self):
        """
        :return: tuple ``(transaction_id, start_timestamp)``.
        :rtype: (unicode, int)
        """
        req = object_service_pb2.TReqStartTransaction()
        resp = self.tp.apply(self.stub.StartTransaction, [req])
        return resp.transaction_id, resp.start_timestamp

    def commit_transaction(self, transaction_id):
        """
        :type transaction_id: unicode
        """
        req = object_service_pb2.TReqCommitTransaction()
        req.transaction_id = transaction_id
        self.tp.apply(self.stub.CommitTransaction, [req])

    def get_pod_set_ignore(self, rs_id, timestamp=None, selectors=None):
        """
        Fetch a pod set, returning None when it does not exist.

        NOTE(review): ``selectors`` is forwarded to the loader unchanged
        (possibly None), while the request itself falls back to
        DEFAULT_OBJECT_SELECTORS. Confirm that the loader handles None
        consistently.

        :type rs_id: str
        :type timestamp: int | None
        :type selectors: list[str] | None
        :rtype: yp.data_model.TPodSet | None
        """
        v = self._get_object_values(object_id=rs_id,
                                    object_type=data_model.OT_POD_SET,
                                    timestamp=timestamp,
                                    selectors=selectors,
                                    ignore_nonexistent=True)
        if not v:
            return None
        return self.loader.load_object(data_model.TPodSet, selectors, v)

    def get_replica_set_ignore(self, rs_id, selectors=None):
        """
        Fetch a replica set, returning None when it does not exist.

        :type rs_id: str
        :type selectors: list[str] | None
        :rtype: yp.data_model.TReplicaSet | None
        """
        v = self._get_object_values(object_id=rs_id,
                                    object_type=data_model.OT_REPLICA_SET,
                                    selectors=selectors,
                                    ignore_nonexistent=True)
        if not v:
            return None
        return self.loader.load_object(data_model.TReplicaSet, selectors, v)

    def get_multi_cluster_replica_set_ignore(self, mcrs_id, selectors=None):
        """
        Fetch a multi-cluster replica set, returning None when it does not
        exist.

        :type mcrs_id: str
        :type selectors: list[str] | None
        :rtype: yp.data_model.TMultiClusterReplicaSet | None
        """
        v = self._get_object_values(object_id=mcrs_id,
                                    object_type=data_model.OT_MULTI_CLUSTER_REPLICA_SET,
                                    selectors=selectors,
                                    ignore_nonexistent=True)
        if not v:
            return None
        return self.loader.load_object(data_model.TMultiClusterReplicaSet, selectors, v)

    def pod_set_exists(self, ps_id, timestamp=None):
        """
        :type ps_id: str
        :type timestamp: int | None
        :rtype: bool
        """
        return self._object_exists(object_type=data_model.OT_POD_SET,
                                   object_id=ps_id,
                                   timestamp=timestamp)

    def get_replica_set(self, rs_id, timestamp, selectors):
        """
        Fetch a replica set that is expected to exist.

        :type rs_id: str
        :type timestamp: int | None
        :type selectors: list[str] | None
        :rtype: yp.data_model.TReplicaSet
        """
        vals = self._get_object_values(object_id=rs_id,
                                       object_type=data_model.OT_REPLICA_SET,
                                       timestamp=timestamp,
                                       selectors=selectors)
        return self.loader.load_object(data_model.TReplicaSet, selectors, vals)

    def select_object_ids(self, object_type, batch_size, object_class, timestamp=None,
                          query=None):
        """
        Return ids of all objects of ``object_type`` matching ``query``.

        Pages through SelectObjects ``batch_size`` objects at a time, and stops
        at the first page that contains fewer than ``batch_size`` objects.

        :type object_type: str
        :type batch_size: int
        :param object_class: proto class used to decode the id payloads.
        :type timestamp: int | None
        :type query: str | None
        :rtype: list[str]
        """
        rv = []
        continuation_token = None
        while True:
            batch, continuation_token = self._select_objects_values(
                object_type=object_type,
                limit=batch_size,
                query=query,
                timestamp=timestamp,
                selectors=self.loader.ID_SELECTORS,
                fetch_timestamps=False,
                continuation_token=continuation_token,
            )
            for values in batch:
                obj = self.loader.load_object_attrs(object_class,
                                                    self.loader.ID_ATTR_PATHS,
                                                    values)
                rv.append(obj.meta.id)
            if len(batch) < batch_size:
                break
        return rv

    def watch_objects(self, object_type, start_timestamp, selectors,
                      event_count_limit, timestamp, time_limit_seconds, query=None):
        """
        Collect watch events for ``object_type`` between ``start_timestamp``
        and ``timestamp``.

        Keeps requesting pages until a response contains fewer than
        ``event_count_limit`` events.

        :type object_type: str
        :type start_timestamp: int
        :type selectors: list[str]
        :type event_count_limit: int
        :type timestamp: int
        :type time_limit_seconds: int
        :type query: str | None
        :return: all events collected across pages, in response order.
        """
        req = object_service_pb2.TReqWatchObjects()
        req.object_type = object_type
        req.start_timestamp = start_timestamp
        req.timestamp = timestamp
        req.event_count_limit = event_count_limit
        req.time_limit.seconds = time_limit_seconds
        req.selector.paths.extend(selectors)
        if query:
            req.filter.query = query
        events = []
        while True:
            resp = self.tp.apply(self.stub.WatchObjects, [req])
            events.extend(resp.events)
            if len(resp.events) < event_count_limit:
                return events
            req.continuation_token = resp.continuation_token
            # start_timestamp and continuation_token both set the left border
            # and cannot be used together. The first query uses
            # start_timestamp; later pages paginate with continuation_token.
            # str() yields the native string type on both Python 2 and 3
            # (this module enables unicode_literals). A bytes field name is
            # not found by protobuf's field lookup on Python 3.
            req.ClearField(str('start_timestamp'))

    def get_objects(self, object_type, object_class, ids,
                    selectors=None,
                    ignore_nonexistent=False, fetch_timestamps=False,
                    timestamp=None):
        """
        Fetch several objects of one type in a single GetObjects request.

        :type object_type: str
        :param object_class: proto class the payloads are decoded into.
        :type ids: list[str]
        :type selectors: list[str] | None
        :param selectors: attribute paths; ``DEFAULT_OBJECT_SELECTORS`` when
            empty or None.
        :param ignore_nonexistent: if True, objects missing from the response
            are skipped.
        :param fetch_timestamps: if True, each result carries a
            ``{selector: timestamp}`` dict.
        :type timestamp: int | None
        :return: list of ``(obj, timestamps)`` pairs. ``timestamps`` is the
            shared ``EMPTY_TIMESTAMPS`` dict when ``fetch_timestamps`` is False.
        """
        req = object_service_pb2.TReqGetObjects()
        req.object_type = object_type
        for object_id in ids:
            subreq = req.subrequests.add()
            subreq.object_id = object_id
        if timestamp is not None:
            req.timestamp = timestamp
        selectors = selectors or DEFAULT_OBJECT_SELECTORS
        req.selector.paths.extend(selectors)
        req.format = self.loader.PAYLOAD_FORMAT
        req.options.ignore_nonexistent = ignore_nonexistent
        req.options.fetch_timestamps = fetch_timestamps
        resp = self.tp.apply(self.stub.GetObjects, [req])
        rv = []
        paths = make_attr_paths(selectors)
        for subresp in resp.subresponses:
            if ignore_nonexistent and not subresp.HasField('result'):
                continue
            obj = self.loader.load_object_attrs(object_class,
                                                paths,
                                                subresp.result.value_payloads)
            if fetch_timestamps:
                timestamps = dict(zip(selectors, subresp.result.timestamps))
                rv.append((obj, timestamps))
            else:
                rv.append((obj, self.EMPTY_TIMESTAMPS))
        return rv

    def create_pod_set(self, ps, transaction_id):
        """
        :type ps: yp.data_model.TPodSet
        :type transaction_id: str | None
        :return: tuple ``(object_id, fqid)`` of the created pod set.
        """
        return self._create_object(object_pb=ps,
                                   object_type=data_model.OT_POD_SET,
                                   transaction_id=transaction_id)

    def create_pods(self, pods, transaction_id):
        """
        :type pods: list[yp.data_model.TPod]
        :type transaction_id: str | None
        :return: ids of the created pods.
        :rtype: list[str]
        """
        return self._create_objects(objects_pb=pods,
                                    object_type=data_model.OT_POD,
                                    transaction_id=transaction_id)

    def update_pod_set(self, ps_id, template):
        """
        Overwrite selected pod set attributes with the values from ``template``.

        Only these paths are updated: /spec/antiaffinity_constraints,
        /spec/node_segment_id, /spec/account_id and /meta/acl.

        :type ps_id: str
        :type template: yp.data_model.TPodSet
        """
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_POD_SET
        req.object_id = ps_id

        upd = req.set_updates.add()
        upd.path = '/spec/antiaffinity_constraints'
        antiaffinity = [yp.common.protobuf_to_dict(c)
                        for c in template.spec.antiaffinity_constraints]
        upd.value = yson.dumps(antiaffinity)

        upd = req.set_updates.add()
        upd.path = '/spec/node_segment_id'
        upd.value = yson.dumps(template.spec.node_segment_id)

        upd = req.set_updates.add()
        upd.path = '/spec/account_id'
        upd.value = yson.dumps(template.spec.account_id)

        upd = req.set_updates.add()
        upd.path = '/meta/acl'
        acl = [yp.common.protobuf_to_dict(entry) for entry in template.meta.acl]
        upd.value = yson.dumps(acl)

        self.tp.apply(self.stub.UpdateObject, [req])

    @staticmethod
    def _add_pod_to_update_pods_request(req, pod):
        """
        Append one subrequest to ``req`` that overwrites the pod's /spec,
        its individual labels, and its /annotations.

        :type req: object_service_pb2.TReqUpdateObjects
        :type pod: yp.data_model.TPod
        """
        subreq = req.subrequests.add()
        subreq.object_type = data_model.OT_POD
        subreq.object_id = pod.meta.id

        upd = subreq.set_updates.add()
        upd.path = '/spec'
        upd.value = yt_yson_bindings.dumps_proto(pod.spec)

        # NOTE(review): an explicit empty map is written when the pod has no
        # secrets, presumably to clear secrets that the serialized /spec
        # would otherwise omit. Confirm.
        if not pod.spec.secrets:
            upd = subreq.set_updates.add()
            upd.path = '/spec/secrets'
            upd.value = yt_yson_bindings.dumps({})

        # Label values are sent as-is (presumably already YSON-encoded).
        for attr in pod.labels.attributes:
            upd = subreq.set_updates.add()
            upd.path = '/labels/{}'.format(attr.key)
            upd.value = attr.value

        upd = subreq.set_updates.add()
        upd.path = '/annotations'
        upd.value = yt_yson_bindings.dumps_proto(pod.annotations)

    def _make_update_pods_threads(self, pods, batch_size):
        """
        Build UpdateObjects requests of up to ``batch_size`` pods each and
        submit every non-empty request to the thread pool.

        If building a pod's subrequest raises, the exception is collected and
        any partially built subrequest for that pod is discarded, so it is
        never sent. The remaining pods are still batched and submitted.

        :type pods: list[yp.data_model.TPod]
        :type batch_size: int
        :return: tuple ``(threads, exceptions)``. ``threads`` holds the spawned
            thread-pool tasks in batch order. ``exceptions`` holds the errors
            raised while building subrequests.
        """
        threads = []
        exceptions = []
        if not pods:
            # Always return a pair so callers can unpack the result.
            return threads, exceptions

        total = len(pods)
        reqs = []
        req = object_service_pb2.TReqUpdateObjects()

        for i, pod in geventutil.gevent_idle_iter(enumerate(pods, 1)):
            size_before = len(req.subrequests)
            try:
                self._add_pod_to_update_pods_request(req, pod)
            except Exception as e:
                # Roll back whatever part of this pod's subrequest was added
                # before the failure.
                del req.subrequests[size_before:]
                exceptions.append(e)
            # Check for a batch boundary even after a failure. Otherwise the
            # pods accumulated for the final batch would be silently dropped.
            if (i % batch_size == 0 or i == total) and req.subrequests:
                reqs.append(req)
                req = object_service_pb2.TReqUpdateObjects()

        for batch_req in geventutil.gevent_idle_iter(reqs):
            threads.append(self.tp.spawn(self.stub.UpdateObjects, batch_req))

        return threads, exceptions

    def update_pods(self, pods, batch_size):
        """
        Update ``pods`` in parallel batches of ``batch_size``.

        Waits for every batch to finish, then raises the first exception
        encountered, whether from building a request or from an RPC.

        :type pods: list[yp.data_model.TPod]
        :type batch_size: int
        """
        threads, exceptions = self._make_update_pods_threads(pods, batch_size)
        for t in threads:
            try:
                t.get()
            except Exception as e:
                exceptions.append(e)
        if exceptions:
            raise exceptions[0]

    def safe_update_pods(self, pods):
        """
        Update ``pods`` one per request and report scheduling failures.

        :type pods: list[yp.data_model.TPod]
        :return: pods whose update raised ``YpPodSchedulingFailure``.
        :rtype: list[yp.data_model.TPod]
        :raises Exception: the first other error, after all updates finished.
        """
        threads, exceptions = self._make_update_pods_threads(pods, batch_size=1)
        failed_pods = []
        # With batch_size=1, threads[i] corresponds to pods[i] unless some
        # pod failed to build. In that case ``exceptions`` is non-empty and
        # we raise below, so failed_pods is never returned.
        for i, t in enumerate(threads):
            try:
                t.get()
            except yp.common.YpPodSchedulingFailure:
                failed_pods.append(pods[i])
            except Exception as e:
                exceptions.append(e)

        if exceptions:
            raise exceptions[0]

        return failed_pods

    def _add_deploy_child(self, object_type, object_id, child):
        """
        Register ``child`` (an fqid) via /control/add_deploy_child.

        :type object_type: str
        :type object_id: str
        :type child: str
        """
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = object_type
        req.object_id = object_id
        upd = req.set_updates.add()
        upd.path = '/control/add_deploy_child'
        upd.value = yson.dumps({'fqid': child})
        self.tp.apply(self.stub.UpdateObject, [req])

    def _remove_deploy_child(self, object_type, object_id, child):
        """
        Unregister ``child`` (an fqid) via /control/remove_deploy_child.

        NOTE(review): ``YpNoSuchTransactionError`` is silently ignored here
        even though no transaction is used. Confirm that this is the intended
        exception and not, for example, a missing-object error.

        :type object_type: str
        :type object_id: str
        :type child: str
        """
        try:
            req = object_service_pb2.TReqUpdateObject()
            req.object_type = object_type
            req.object_id = object_id
            upd = req.set_updates.add()
            upd.path = '/control/remove_deploy_child'
            upd.value = yson.dumps({'fqid': child})
            self.tp.apply(self.stub.UpdateObject, [req])
        except yp.common.YpNoSuchTransactionError:
            pass

    def add_deploy_child_to_rs(self, rs_id, child):
        """
        :type rs_id: str
        :type child: str
        """
        self._add_deploy_child(data_model.OT_REPLICA_SET, rs_id, child)

    def add_deploy_child_to_mcrs(self, mcrs_id, child):
        """
        :type mcrs_id: str
        :type child: str
        """
        self._add_deploy_child(data_model.OT_MULTI_CLUSTER_REPLICA_SET, mcrs_id, child)

    def remove_deploy_child_from_rs(self, rs_id, child):
        """
        :type rs_id: str
        :type child: str
        """
        self._remove_deploy_child(data_model.OT_REPLICA_SET, rs_id, child)

    def remove_deploy_child_from_mcrs(self, mcrs_id, child):
        """
        :type mcrs_id: str
        :type child: str
        """
        self._remove_deploy_child(data_model.OT_MULTI_CLUSTER_REPLICA_SET, mcrs_id, child)

    def update_pods_request_eviction(self, pod_ids, msg):
        """
        Set /control/request_eviction with ``msg`` on every pod, in a single
        UpdateObjects request.

        :type pod_ids: list[str]
        :type msg: str
        """
        req = object_service_pb2.TReqUpdateObjects()
        yson_value = yson.dumps({'message': msg})
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = '/control/request_eviction'
            upd.value = yson_value
        self.tp.apply(self.stub.UpdateObjects, [req])

    def mark_delegate_removing_pods(self, pod_ids):
        """
        Set the DELEGATE_REMOVING_LABEL label to true on every pod, in a
        single UpdateObjects request.

        :type pod_ids: list[str]
        """
        req = object_service_pb2.TReqUpdateObjects()
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = '/labels/{}'.format(DELEGATE_REMOVING_LABEL)
            upd.value = yson.dumps(True)
        self.tp.apply(self.stub.UpdateObjects, [req])

    def delegate_removing_pods(self, pod_ids, msg):
        """
        Request eviction and set the DELEGATE_REMOVING_LABEL label on every
        pod, in a single UpdateObjects request.

        :type pod_ids: list[str]
        :type msg: str
        """
        req = object_service_pb2.TReqUpdateObjects()
        yson_value = yson.dumps({'message': msg})
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = '/control/request_eviction'
            upd.value = yson_value
            upd = subreq.set_updates.add()
            upd.path = '/labels/{}'.format(DELEGATE_REMOVING_LABEL)
            upd.value = yson.dumps(True)
        self.tp.apply(self.stub.UpdateObjects, [req])

    def update_pods_acknowledge_eviction(self, pod_ids, msg, use_evict=False):
        """
        Acknowledge eviction on every pod, in a single UpdateObjects request.

        :type pod_ids: list[str]
        :type msg: str
        :param use_evict: write /control/evict instead of
            /control/acknowledge_eviction.
        :type use_evict: bool
        """
        req = object_service_pb2.TReqUpdateObjects()
        yson_value = yson.dumps({'message': msg})
        path = '/control/evict' if use_evict else '/control/acknowledge_eviction'
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = path
            upd.value = yson_value
        self.tp.apply(self.stub.UpdateObjects, [req])

    def update_pods_acknowledge_maintenance(self, pod_ids, msg):
        """
        Set /control/acknowledge_maintenance with ``msg`` on every pod, in a
        single UpdateObjects request.

        :type pod_ids: list[str]
        :type msg: str
        """
        req = object_service_pb2.TReqUpdateObjects()
        yson_value = yson.dumps({'message': msg})
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = '/control/acknowledge_maintenance'
            upd.value = yson_value
        self.tp.apply(self.stub.UpdateObjects, [req])

    def update_pods_target_state(self, pod_ids, target_state):
        """
        Set /spec/pod_agent_payload/spec/target_state on every pod, in a
        single UpdateObjects request.

        :type pod_ids: list[str]
        :param target_state: numeric EPodAgentTargetState value; it is sent
            as the enum value name.
        :type target_state: int
        """
        req = object_service_pb2.TReqUpdateObjects()
        proto_enum = yp.client.to_proto_enum_by_number(yp.data_model.EPodAgentTargetState, target_state)
        proto_enum_value = yp.client.get_proto_enum_value_name(proto_enum)
        yson_value = yson.dumps(proto_enum_value)
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            subreq = req.subrequests.add()
            subreq.object_type = data_model.OT_POD
            subreq.object_id = p_id
            upd = subreq.set_updates.add()
            upd.path = '/spec/pod_agent_payload/spec/target_state'
            upd.value = yson_value
        self.tp.apply(self.stub.UpdateObjects, [req])

    def update_replica_set_status(self, rs_id, status):
        """
        Overwrite the /status of a replica set.

        :type rs_id: str
        :type status: yp.data_model.TReplicaSetStatus
        """
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_REPLICA_SET
        req.object_id = rs_id
        upd = req.set_updates.add()
        upd.path = '/status'
        upd.value = yt_yson_bindings.dumps_proto(status)
        self.tp.apply(self.stub.UpdateObject, [req])

    def update_replica_set_annotations(self, rs_id, annotations, transaction_id=None):
        """
        Overwrite the /annotations of a replica set.

        :type rs_id: str
        :type annotations: yp.data_model.TAttributeDictionary
        :param transaction_id: transaction to run in. If empty or None, the
            update is sent without a transaction id, consistent with the
            other methods of this class.
        :type transaction_id: str | None
        """
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_REPLICA_SET
        req.object_id = rs_id
        if transaction_id:
            req.transaction_id = transaction_id

        upd = req.set_updates.add()
        upd.path = '/annotations'
        upd.value = yt_yson_bindings.dumps_proto(annotations)
        self.tp.apply(self.stub.UpdateObject, [req])

    def update_multi_cluster_replica_set_status(self, mcrs_id, status):
        """
        Overwrite the /status of a multi-cluster replica set.

        :type mcrs_id: str
        :type status: yp.data_model.TMultiClusterReplicaSetStatus
        """
        req = object_service_pb2.TReqUpdateObject()
        req.object_type = data_model.OT_MULTI_CLUSTER_REPLICA_SET
        req.object_id = mcrs_id
        upd = req.set_updates.add()
        upd.path = '/status'
        upd.value = yt_yson_bindings.dumps_proto(status)
        self.tp.apply(self.stub.UpdateObject, [req])

    def remove_pods(self, pod_ids, transaction_id=None):
        """
        Remove all ``pod_ids`` in a single RemoveObjects request.

        :type pod_ids: list[str]
        :type transaction_id: str | None
        """
        req = object_service_pb2.TReqRemoveObjects()
        if transaction_id:
            req.transaction_id = transaction_id
        for p_id in geventutil.gevent_idle_iter(pod_ids):
            sub = req.subrequests.add()
            sub.object_type = data_model.OT_POD
            sub.object_id = p_id
        self.tp.apply(self.stub.RemoveObjects, [req])

    def remove_pod_set(self, pod_set_id, transaction_id=None):
        """
        :type pod_set_id: str
        :type transaction_id: str | None
        """
        req = object_service_pb2.TReqRemoveObject()
        req.object_type = data_model.OT_POD_SET
        req.object_id = pod_set_id
        if transaction_id is not None:
            req.transaction_id = transaction_id
        self.tp.apply(self.stub.RemoveObject, [req])
