# coding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

import six

from requests.exceptions import ReadTimeout

from infra.nanny.yp_lite_api.proto import endpoint_sets_pb2
from infra.nanny.yp_lite_api.proto import endpoint_sets_api_pb2
from infra.nanny.yp_lite_api.py_stubs.endpoint_sets_api_stub import YpLiteUIEndpointSetsServiceStub

from saas.library.python.nanny_proto.rpc_client_base import NannyRpcClientBase
from .errors import ApiRequestTimeout


class EndpointSet(NannyRpcClientBase):
    """Client-side handle for a YP-lite endpoint set in a single cluster.

    An instance is identified by ``(cluster, endpoint_set_id)``. The underlying
    ``EndpointSet`` protobuf message is either supplied to the constructor or
    fetched from the API on first access to :attr:`proto_msg`, then cached.
    """
    # NOTE(review): presumably consumed by NannyRpcClientBase (not visible here)
    # for equality/hashing -- confirm the attribute names it expects.
    _identity = ['cluster', 'endpoint_set_id']
    _RPC_URL = 'https://yp-lite-ui.nanny.yandex-team.ru/api/yplite/endpoint-sets/'
    _API_STUB = YpLiteUIEndpointSetsServiceStub

    def __init__(self, cluster, endpoint_set_id, proto_msg=None):
        """
        :param cluster: cluster name; normalized with :meth:`normalize_cluster`.
        :param endpoint_set_id: endpoint set identifier.
        :param proto_msg: optional already-known endpoint set message; when
            omitted it is requested from the API on first use of :attr:`proto_msg`.
        """
        super(EndpointSet, self).__init__()
        self._cluster = self.normalize_cluster(cluster)
        self._id = endpoint_set_id
        self._proto_msg = proto_msg

    def __repr__(self):
        return 'EndpointSet({}, {})'.format(self._cluster, self._id)

    def __str__(self):
        return '{}@{}'.format(self._cluster.lower(), self._id)

    @staticmethod
    def normalize_cluster(cluster):
        """Return ``cluster`` upper-cased with ``-`` replaced by ``_`` (e.g. ``sas-test`` -> ``SAS_TEST``)."""
        return cluster.upper().replace('-', '_')

    @classmethod
    def create(cls, clusters, service_id, endpoint_set_id, pod_filter='', protocol='tcp', port=80, description=''):
        """Create an endpoint set with the same meta and spec in each given cluster.

        Requests are sent sequentially. If one fails, endpoint sets already
        created in earlier clusters are not rolled back.

        :param clusters: iterable of cluster names (each is normalized).
        :param service_id: id of the service the endpoint set belongs to.
        :param endpoint_set_id: id of the endpoint set to create.
        :param pod_filter: pod filter expression for the spec.
        :param protocol: protocol for the spec.
        :param port: port for the spec.
        :param description: free-form description for the spec.
        :returns: list of ``cls`` instances, one per cluster in input order, each
            carrying the endpoint set message returned by the API.
        """
        cls._init_client()
        meta = endpoint_sets_pb2.EndpointSetMeta(id=endpoint_set_id, service_id=service_id)
        spec = endpoint_sets_pb2.EndpointSetSpec(pod_filter=pod_filter, protocol=protocol, port=port, description=description)
        result = []
        for cluster in clusters:
            cluster = cls.normalize_cluster(cluster)
            req = endpoint_sets_api_pb2.CreateEndpointSetRequest(meta=meta, spec=spec, cluster=cluster)
            cls.LOGGER.info('Creating EndpointSet %s in %s with request %s', endpoint_set_id, cluster, req)
            res = cls._CLIENT.create_endpoint_set(req).endpoint_set
            # Use cls (not the hard-coded class name) so subclasses get instances
            # of their own type, consistent with search() and list().
            result.append(cls(cluster, res.meta.id, res))
        return result

    @classmethod
    def search(cls, cluster, substring, limit=100):
        """Search endpoint sets in ``cluster`` whose id matches ``substring``.

        :param cluster: cluster name (normalized before use).
        :param substring: search string passed to the API.
        :param limit: maximum number of results requested from the API.
        :returns: list of ``cls`` instances without a cached message; their data
            is fetched lazily on first access to :attr:`proto_msg`.
        """
        cls._init_client()
        cluster = cls.normalize_cluster(cluster)
        req = endpoint_sets_api_pb2.SearchEndpointSetsRequest(substring=substring, limit=limit, cluster=cluster)
        res = cls._CLIENT.search_endpoint_sets(req)  # type: endpoint_sets_api_pb2.SearchEndpointSetsResponse
        return [cls(cluster, endpoint_set.id) for endpoint_set in res.endpoint_sets]

    @classmethod
    def list(cls, cluster, service_id):
        """List all endpoint sets of ``service_id`` in ``cluster``.

        :param cluster: cluster name (normalized before use).
        :param service_id: id of the owning service.
        :returns: list of ``cls`` instances, each carrying its message from the response.
        """
        cls._init_client()
        cluster = cls.normalize_cluster(cluster)
        req = endpoint_sets_api_pb2.ListEndpointSetsRequest(service_id=service_id, cluster=cluster)
        res = cls._CLIENT.list_endpoint_sets(req)  # type: endpoint_sets_api_pb2.ListEndpointSetsResponse
        return [cls(cluster, endpoint_set.meta.id, endpoint_set) for endpoint_set in res.endpoint_sets]

    @property
    def cluster(self):
        """Normalized cluster name."""
        return self._cluster

    @property
    def id(self):
        """Endpoint set id."""
        return self._id

    @property
    def proto_msg(self):
        """Endpoint set protobuf message, fetched from the API on first access and cached."""
        if self._proto_msg is None:
            # Make sure the shared client exists, as every classmethod does, so an
            # instance constructed directly can still fetch its data.
            self._init_client()
            req = endpoint_sets_api_pb2.GetEndpointSetRequest(id=self._id, cluster=self._cluster)
            self._proto_msg = self._CLIENT.get_endpoint_set(req).endpoint_set
        return self._proto_msg

    @property
    def service_id(self):
        """Id of the service owning this endpoint set (from ``meta``)."""
        return self.proto_msg.meta.service_id

    @property
    def version(self):
        """Current version from ``meta``; sent with update/remove requests."""
        return self.proto_msg.meta.version

    @property
    def pod_filter(self):
        """Pod filter expression from ``spec``. Assigning calls :meth:`update`."""
        return self.proto_msg.spec.pod_filter

    @pod_filter.setter
    def pod_filter(self, value):
        self.update(pod_filter=value)

    @property
    def protocol(self):
        """Protocol from ``spec``. Assigning calls :meth:`update`."""
        return self.proto_msg.spec.protocol

    @protocol.setter
    def protocol(self, value):
        self.update(protocol=value)

    @property
    def port(self):
        """Port from ``spec``."""
        return self.proto_msg.spec.port

    @property
    def description(self):
        """Description from ``spec``. Assigning calls :meth:`update`."""
        return self.proto_msg.spec.description

    @description.setter
    def description(self, value):
        self.update(description=value)

    def update(self, **kwargs):
        """Update spec fields of the endpoint set.

        Only keyword arguments whose value is not ``None`` and differs from the
        current spec are applied. If nothing differs, a warning is logged and no
        request is sent. On success the cached message is replaced by the one
        returned by the API.

        :param kwargs: spec field names (``pod_filter``, ``protocol``, ``port``,
            ``description``) mapped to new values. An unknown field name raises
            ``AttributeError``.
        """
        # Go through the lazy property: instances from search() or constructed
        # without proto_msg have _proto_msg set to None.
        current = self.proto_msg.spec
        # NOTE(review): only these four spec fields are carried over; if
        # EndpointSetSpec has other fields they are not sent -- confirm intended.
        spec = endpoint_sets_pb2.EndpointSetSpec(
            pod_filter=current.pod_filter,
            protocol=current.protocol,
            port=current.port,
            description=current.description
        )
        changes = {}
        for k, v in six.iteritems(kwargs):
            if v is not None and getattr(spec, k) != v:
                changes[k] = v
        if not changes:
            self.LOGGER.warning('No changes in %s for %s', kwargs, self)
            return
        self.LOGGER.info('Updating %s; request: %s; changes: %s', self, kwargs, changes)
        for k, v in six.iteritems(changes):
            setattr(spec, k, v)
        self._init_client()
        req = endpoint_sets_api_pb2.UpdateEndpointSetRequest(id=self._id, version=self.version, spec=spec, cluster=self._cluster)
        self._proto_msg = self._CLIENT.update_endpoint_set(req).endpoint_set

    def remove(self):
        """Remove the endpoint set at its current version.

        :raises ApiRequestTimeout: when the request fails with ``requests`` ``ReadTimeout``.
        """
        req = endpoint_sets_api_pb2.RemoveEndpointSetRequest(id=self._id, version=self.version, cluster=self._cluster)
        self._init_client()
        try:
            self._CLIENT.remove_endpoint_set(req)
        except ReadTimeout as e:
            raise ApiRequestTimeout(request=req, inner_exception=e)
