# coding: utf8
from __future__ import print_function

import collections
import datetime
import logging
import socket
import uuid

import grpc
from yp.client import YpClient, find_token

from travel.orders.tools.library.tvm import get_tvm_service_ticket

# YP clusters that resolve_hosts() queries, in this order, when looking up
# endpoints.
clusters = ['vla', 'sas', 'man']

# Per-environment settings, keyed by environment name.
# `tvm_service_id` is passed to get_tvm_service_ticket() as the destination
# service id when a channel is created for a non-dev environment; for 'dev'
# no ticket is requested, so that entry is not read by the channel factory.
config = dict(
    prod=dict(tvm_service_id=2018392),
    testing=dict(tvm_service_id=2018390),
    dev=dict(tvm_service_id=2018390),
)


class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
    """Adapter that funnels all four gRPC client call arities into one function.

    The wrapped function is invoked as
    ``fn(client_call_details, request_iterator, request_streaming, response_streaming)``
    and must return a ``(call_details, request_iterator, postprocess)`` triple.
    ``postprocess`` may be falsy; otherwise it is applied to whatever the
    continuation returns, and its return value is handed back to gRPC.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def _dispatch(self, continuation, client_call_details, requests,
                  request_streaming, response_streaming):
        """Run the wrapped function, invoke the continuation, post-process the outcome."""
        details, requests, postprocess = self._fn(
            client_call_details, requests, request_streaming, response_streaming)
        if request_streaming:
            outgoing = requests
        else:
            # Unary requests travel through the function as a one-element
            # iterator; unwrap the single message again for the continuation.
            outgoing = next(requests)
        outcome = continuation(details, outgoing)
        if postprocess:
            return postprocess(outcome)
        return outcome

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._dispatch(
            continuation, client_call_details, iter((request,)), False, False)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._dispatch(
            continuation, client_call_details, iter((request,)), False, True)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._dispatch(
            continuation, client_call_details, request_iterator, True, False)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._dispatch(
            continuation, client_call_details, request_iterator, True, True)


class _ClientCallDetails(
        collections.namedtuple(
            '_ClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials']),
        grpc.ClientCallDetails):
    """Concrete ``grpc.ClientCallDetails`` backed by a namedtuple.

    Lets an interceptor build a replacement call-details object carrying the
    fields ``method``, ``timeout``, ``metadata`` and ``credentials``.
    """


def resolve_hosts(args):
    """Lazily yield the ``fqdn`` of each endpoint registered for ``args.env``.

    Every cluster in ``clusters`` is queried in turn through a ``YpClient``
    for ``endpoint`` objects whose ``/meta/endpoint_set_id`` equals
    ``travel-hotels-administrator-<env>``.

    :param args: object exposing an ``env`` attribute.
    :return: generator of ``fqdn`` values, taken from the second selected
        attribute of each row (selected as ``/spec``).
    """
    target_set = f"travel-hotels-administrator-{args.env}"
    endpoint_filter = f"[/meta/endpoint_set_id]=\"{target_set}\""
    for dc in clusters:
        with YpClient(dc, config=dict(token=find_token())) as client:
            rows = client.select_objects(
                "endpoint", selectors=["/meta", "/spec"], filter=endpoint_filter)
            # Each row unpacks in selector order: (/meta, /spec).
            for _meta, spec in rows:
                yield spec['fqdn']


def create_grpc_channel_raw(host='localhost', port=29855, timeout=5, env='dev', tvm_client_id=None, tvm_client_secret=None):
    """Create an insecure gRPC channel that decorates and logs unary-unary calls.

    Every call made through the returned channel gets the metadata entries
    ``ya-grpc-call-id``, ``ya-grpc-started-at``, ``ya-grpc-fqdn`` and
    ``x-ya-service-ticket`` (overriding same-named entries supplied by the
    caller), falls back to ``timeout`` when the call has no timeout of its own,
    and is logged on request and on completion.

    :param host: host name to connect to.
    :param port: port to connect to.
    :param timeout: default per-call timeout, used when a call specifies none.
    :param env: key into ``config``; for ``'dev'`` no TVM ticket is requested
        and the ticket header is sent with an empty value.
    :param tvm_client_id: forwarded to ``get_tvm_service_ticket``.
    :param tvm_client_secret: forwarded to ``get_tvm_service_ticket``.
    :return: the intercepted ``grpc.Channel``.
    :raises KeyError: if ``env`` is neither ``'dev'`` nor a key of ``config``.

    Calls with a streaming request or response are rejected by the interceptor
    with ``NotImplementedError``.
    """
    if env == 'dev':
        ticket = ''
    else:
        ticket = get_tvm_service_ticket(
            tvm_client_id=tvm_client_id,
            tvm_service_id=config[env]['tvm_service_id'],
            tvm_client_secret=tvm_client_secret,
            skip_authentication=False
        )

    def ya_grpc_interceptor(client_call_details, request_iterator,
                            request_streaming, response_streaming):
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and the code below would then forward only the first message of a
        # streaming request.
        if request_streaming or response_streaming:
            raise NotImplementedError(
                "ya_grpc_interceptor supports unary-unary calls only")

        ya_call_id = str(uuid.uuid1())
        # Same "<naive UTC ISO timestamp>Z" text that utcnow() produced,
        # without relying on the deprecated datetime.utcnow().
        ya_started_at = datetime.datetime.now(
            datetime.timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        ya_fqdn = socket.getfqdn()

        injected = {
            "ya-grpc-call-id": ya_call_id,
            "ya-grpc-started-at": ya_started_at,
            "ya-grpc-fqdn": ya_fqdn,
            "x-ya-service-ticket": ticket,
        }
        # Keep the caller's metadata as a list of pairs so repeated keys
        # survive; only entries that collide with the injected keys are dropped.
        outgoing_metadata = [
            (key, value)
            for key, value in (client_call_details.metadata or ())
            if key not in injected
        ]
        outgoing_metadata.extend(injected.items())

        if client_call_details.timeout is not None:
            call_timeout = client_call_details.timeout
        else:
            call_timeout = timeout

        new_details = _ClientCallDetails(
            client_call_details.method,
            call_timeout,
            outgoing_metadata,
            client_call_details.credentials)

        request = next(request_iterator)
        logging.info("<- REQ %s (CallId: %s)", new_details.method, ya_call_id)
        logging.debug("%s\n%r", request.__class__, request)

        def complete(rendezvous):
            """Log status, metadata and (on success) the response of a finished call."""
            code = rendezvous.code()
            logging.info("-> RSP %s (CallId: %s, Code: %s, Cancelled: %s)",
                         new_details.method, ya_call_id,
                         code, rendezvous.cancelled())

            response_metadata = {}
            # Trailing metadata is merged second, so it wins on key clashes.
            # `or ()` tolerates call objects that report no metadata (None).
            for source in (rendezvous.initial_metadata(), rendezvous.trailing_metadata()):
                for key, value in source or ():
                    response_metadata[key] = value
            for key, value in sorted(response_metadata.items()):
                logging.debug("Meta: [%s] => %r", key, value)

            # Only a successful call has a response to log. Calling result() on
            # a failed call would raise from inside the interceptor; instead
            # the call object is returned and the error surfaces when the
            # caller asks for the result.
            if not rendezvous.cancelled() and code == grpc.StatusCode.OK:
                response = rendezvous.result()
                logging.debug("%s\n%r", response.__class__, response)

        def postprocess(rendezvous):
            complete(rendezvous)
            return rendezvous

        return new_details, iter((request,)), postprocess

    endpoint = "%s:%s" % (host, port)
    logging.info("Creating a channel to '%s'", endpoint)
    channel = grpc.insecure_channel(endpoint)
    channel = grpc.intercept_channel(channel, _GenericClientInterceptor(ya_grpc_interceptor))
    return channel


def create_grpc_channel(args):
    """Create an intercepted gRPC channel from parsed command-line style arguments.

    The host is chosen as follows: ``args.host`` if set; otherwise
    ``'localhost'`` for the ``'dev'`` environment; otherwise the first host
    yielded by ``resolve_hosts(args)``.

    :param args: object exposing ``host``, ``env``, ``port``, ``timeout``,
        ``tvm_client_id`` and ``tvm_client_secret`` attributes.
    :return: the channel returned by ``create_grpc_channel_raw``.
    :raises RuntimeError: if the host has to be resolved and no endpoint is found.
    """
    if args.host:
        host = args.host
    elif args.env == 'dev':
        host = 'localhost'
    else:
        hosts = resolve_hosts(args)
        try:
            # A default avoids leaking a bare StopIteration to the caller,
            # which carries no message and can silently end an enclosing
            # iteration.
            host = next(hosts, None)
        finally:
            # Only the first host is needed; closing the generator now lets it
            # leave its `with YpClient(...)` block instead of waiting for GC.
            hosts.close()
        if host is None:
            raise RuntimeError(
                "No endpoints found for environment %r; pass an explicit host" % (args.env,))

    return create_grpc_channel_raw(
        host=host,
        port=args.port,
        timeout=args.timeout,
        env=args.env,
        tvm_client_id=args.tvm_client_id,
        tvm_client_secret=args.tvm_client_secret
    )
