import argparse
import sys
import socket
import logging
from string import Formatter

from infra.yp_service_discovery.api import api_pb2
from infra.yp_service_discovery.python.resolver.resolver import Resolver
from google.protobuf.json_format import MessageToDict


# YP clusters queried by default; also the only values accepted by --cluster.
CLUSTERS = [
    'iva',
    'myt',
    'sas',
    'man',
    'vla',
    'xdc',
]
# Fallback log level for the logging setup done in parse_args.
DEFAULT_LEVEL = logging.WARNING


def _parse_bool(value):
    """Convert a command-line word into a boolean (used by --append-host-to-client).

    ``type=bool`` is an argparse pitfall: ``bool("false")`` is True because
    every non-empty string is truthy. Explicit negative spellings
    (case-insensitive) map to False; any other non-empty value still means
    True, so existing invocations such as ``--append-host-to-client 1`` keep
    working.
    """
    return value.lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def parse_args(argv):
    """Parse command-line arguments, configure logging and validate the formatter.

    Args:
        argv: argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        The parsed ``argparse.Namespace``. When ``--append-host-to-client`` is
        enabled, ``client_name`` is rewritten to ``"<client>:<hostname>"``.

    Raises:
        SystemExit: on usage errors (standard argparse behaviour).
        Exception: if ``--entry-formatter`` references a field that is not a
            JSON name of ``api_pb2.TEndpoint`` (or uses positional ``{}``).
    """
    parser = argparse.ArgumentParser()
    # nargs='?' + const=True: usable as a bare flag, while an explicit value
    # (e.g. "--append-host-to-client false") is still accepted as before.
    parser.add_argument("--append-host-to-client", type=_parse_bool, nargs='?', const=True, default=False,
                        help="Append current host name to client_name in request")
    parser.add_argument("--cluster", nargs='*', default=CLUSTERS, choices=CLUSTERS, help="Cluster name.")
    parser.add_argument("--address", default='sd.yandex.net:8081')
    parser.add_argument("--endpoint-set-id", required=True, help="Endpoint set id.")
    parser.add_argument("--client-name", required=True, help="Client name.")
    parser.add_argument("--labels", nargs='*', help="Additional labels")
    # NOTE(review): --filter is parsed but not used by any code visible here.
    parser.add_argument("--filter", help="Filter output")
    parser.add_argument("--entry-formatter",
                        help="str.format template applied to each endpoint; fields are TEndpoint JSON names")
    parser.add_argument("--join", help="Join entries with separator")
    parser.add_argument("--comment", help="Comment for SD request")
    parser.add_argument("--timeout", type=int, default=5, help="Timeout")
    logging_args = parser.add_mutually_exclusive_group()
    logging_args.add_argument("-v", "--verbose", action="store_const", const=logging.INFO, dest="log_level",
                              help="verbose mode")
    logging_args.add_argument("-d", "--debug", action="store_const", const=logging.DEBUG, dest="log_level",
                              help="debug mode")
    logging_args.add_argument("-q", "--quiet", action="store_const", const=logging.ERROR, dest="log_level",
                              help="quiet mode (errors only)")
    # argparse always creates the dest attribute (None by default), so a
    # getattr(..., DEFAULT_LEVEL) fallback would never fire; set it here.
    parser.set_defaults(log_level=DEFAULT_LEVEL)
    result = parser.parse_args(argv)
    logging.basicConfig(stream=sys.stderr, level=result.log_level)

    if result.append_host_to_client:
        result.client_name = '{client}:{hostname}'.format(client=result.client_name, hostname=socket.gethostname())
        logging.debug(f'Using client {result.client_name}')

    if result.entry_formatter:
        # Formatter.parse yields field_name None for pure-literal chunks and
        # '' for positional "{}"; keep the latter so it gets rejected below
        # instead of crashing later when the template is applied.
        fieldnames = [
            fname for _, fname, _, _ in Formatter().parse(result.entry_formatter)
            if fname is not None
        ]
        logging.debug(f'Got fields {",".join(fieldnames)} in formatter string')
        available_names = [f.json_name for f in api_pb2.TEndpoint.DESCRIPTOR.fields]
        logging.debug(f'Available fields are {",".join(available_names)} (from TEndpoint descriptor)')
        # Only the leading name is looked up in the mapping; "{name.attr}" and
        # "{name[0]}" then access attributes/items of that value.
        bad_names = [
            name for name in fieldnames
            if name.split('.', 1)[0].split('[', 1)[0] not in available_names
        ]
        if bad_names:
            raise Exception(
                'Bad arguments: unknown field(s) {bad} in --entry-formatter. '
                'Available keys are {fields}'.format(
                    bad=', '.join(repr(name) for name in bad_names),
                    fields=', '.join(available_names)))
    return result


def resolve(args, cluster=None):
    """Resolve ``args.endpoint_set_id`` via the SD resolver.

    Args:
        args: parsed CLI namespace; reads ``address``, ``client_name``,
            ``timeout``, ``endpoint_set_id``, ``labels``, ``comment`` and,
            when ``cluster`` is None, ``cluster``.
        cluster: cluster name to query. When None, every cluster in
            ``args.cluster`` (a single name or a list of names) is queried and
            the results are concatenated in that order.

    Returns:
        List of endpoint dicts produced by ``MessageToDict`` (JSON-style
        keys). Empty list when the response carries no endpoint set or no
        endpoints.
    """
    if cluster is None:
        # args.cluster comes from nargs='*' and is normally a list, which
        # cannot be stored in the scalar cluster_name field; fan out instead.
        if isinstance(args.cluster, str):
            clusters = [args.cluster]
        else:
            clusters = list(args.cluster or [])
        endpoints = []
        for name in clusters:
            endpoints.extend(resolve(args, name))
        return endpoints

    resolver = Resolver(grpc_address=args.address, client_name=args.client_name, timeout=args.timeout)

    request = api_pb2.TReqResolveEndpoints()
    request.cluster_name = cluster
    logging.debug(f'Cluster {cluster}')
    request.endpoint_set_id = args.endpoint_set_id
    logging.debug(f'Endpoint_set_id {args.endpoint_set_id}')
    if args.labels:
        logging.debug(f'Append labels {args.labels}')
        request.label_selectors.extend(args.labels)
    if args.comment:
        logging.debug(f'Comment {args.comment}')
        # The --comment value is sent as the request's ruid field.
        request.ruid = args.comment
    response = resolver.resolve_endpoints(request)
    logging.info(f'Timestamp: {response.timestamp}')
    logging.info(f'Endpoint-set_id: {response.endpoint_set.endpoint_set_id}')
    logging.debug(f'Full response:\n{response}')
    result = MessageToDict(response)
    logging.debug(f'Got response: {result}')
    return result.get('endpointSet', {}).get('endpoints', [])


class _BlankMissing(dict):
    """Mapping for ``str.format_map`` that renders absent keys as ''.

    By default ``MessageToDict`` omits proto fields that hold their default
    value, so an endpoint may lack a key even though parse_args validated it
    as a real TEndpoint field; render it empty instead of raising KeyError.
    """

    def __missing__(self, key):
        return ''


def main(argv):
    """Command-line entry point: resolve the endpoint set and print the result.

    Without ``--entry-formatter`` the endpoint dicts are printed as-is. With it,
    each endpoint is rendered through the template. ``--join`` then joins the
    entries into one string; otherwise the list itself is printed.

    Args:
        argv: argument list without the program name.

    Returns:
        '' when an empty ``--cluster`` list was given (nothing is printed),
        otherwise None.
    """
    args = parse_args(argv)
    if not args.cluster:
        return ''

    result = []
    for cluster in args.cluster:
        result.extend(resolve(args, cluster))

    if args.entry_formatter:
        result = [args.entry_formatter.format_map(_BlankMissing(entry)) for entry in result]
    if args.join:
        # Without --entry-formatter the entries are still dicts, which
        # str.join rejects; stringify them rather than crash.
        result = args.join.join(str(entry) for entry in result)
    print(result)

if __name__ == "__main__":
    # Script entry: forward every argument after the program name to main().
    main(argv=sys.argv[1:])
