import grpc
import json
from concurrent import futures
from google.protobuf import json_format
from collections import defaultdict

from infra.yp_service_discovery.api import api_pb2_grpc, api_pb2


class SdServiceServicer(api_pb2_grpc.TServiceDiscoveryServiceServicer):
    """Service discovery servicer that answers with pre-built, per-cluster responses.

    Every RPC is served by looking up ``request.cluster_name`` in the mapping
    handed to the constructor; nothing is computed per request.
    """

    def __init__(self, cluster_to_endpoints_response, cluster_to_pods_response):
        """
        Remember the canned responses to serve.

        :param cluster_to_endpoints_response: responses for ResolveEndpoints, keyed by cluster name
        :type cluster_to_endpoints_response: dict[str, api_pb2.TRspResolveEndpoints]
        :param cluster_to_pods_response: responses for ResolvePods, keyed by cluster name
        :type cluster_to_pods_response: dict[str, api_pb2.TRspResolvePods]
        """
        self.cluster_to_pods_response = cluster_to_pods_response
        self.cluster_to_endpoints_response = cluster_to_endpoints_response

    def ResolveEndpoints(self, request, _):
        """
        Return the stored endpoints response for ``request.cluster_name``.

        The second argument is unused. The lookup uses plain indexing, so a
        mapping without the cluster (and without a default factory) raises
        ``KeyError``.
        """
        cluster = request.cluster_name
        return self.cluster_to_endpoints_response[cluster]

    def ResolvePods(self, request, _):
        """
        Return the stored pods response for ``request.cluster_name``.

        The second argument is unused. The lookup uses plain indexing, so a
        mapping without the cluster (and without a default factory) raises
        ``KeyError``.
        """
        cluster = request.cluster_name
        return self.cluster_to_pods_response[cluster]


if __name__ == "__main__":
    # Stand-alone stub server.
    #
    # Usage: <script> BIND_ADDR [--endpoints-responses JSON] [--pods-responses JSON]
    #
    # Each optional argument is a JSON object mapping a cluster name to the
    # JSON form of the response message to return for that cluster.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('bind_addr')
    parser.add_argument('--endpoints-responses', default=None)
    parser.add_argument('--pods-responses', default=None)
    args = parser.parse_args()

    # defaultdict: a cluster that was not configured on the command line is
    # answered with a default-constructed (empty) response message instead of
    # a KeyError inside the RPC handler.
    cluster_to_endpoints_response = defaultdict(api_pb2.TRspResolveEndpoints)
    cluster_to_pods_response = defaultdict(api_pb2.TRspResolvePods)

    for raw_json, responses in (
        (args.endpoints_responses, cluster_to_endpoints_response),
        (args.pods_responses, cluster_to_pods_response),
    ):
        if not raw_json:
            continue
        # .items() instead of .iteritems(): the latter exists only on
        # Python 2 dicts and raises AttributeError on Python 3.
        for cluster_name, response_data in json.loads(raw_json).items():
            # Indexing the defaultdict creates the empty message that
            # ParseDict then fills in place.
            json_format.ParseDict(response_data, responses[cluster_name],
                                  ignore_unknown_fields=True)

    servicer = SdServiceServicer(cluster_to_endpoints_response, cluster_to_pods_response)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    api_pb2_grpc.add_TServiceDiscoveryServiceServicer_to_server(servicer, server)
    server.add_insecure_port(args.bind_addr)
    server.start()
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to end the stub: shut the server down
        # explicitly (no grace period) rather than silently falling through.
        server.stop(0)
