import logging
import pickle
import sys
import tempfile
import time
import traceback

import six
import grpc

from collections import namedtuple
from concurrent import futures
from google.protobuf import empty_pb2

from sandbox.common import patterns
from tasklet.api import tasklet_pb2

from . import dispatch
from . import utils


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Everything known about one service registered on a LocalGrpcServer:
#   ref   -- tasklet_pb2.ServiceRef pointing at the server the service was registered on
#   entry -- the context entry (ContextEntry) the service was declared in
#   descr -- the tasklet_pb2.LocalService unpacked from that entry
ServiceInfo = namedtuple("ServiceInfo", "ref entry descr")


class LocalGrpcServer(object):
    """
    In-process gRPC server bound to a unix socket, together with a client channel to it.

    Local services are added with `register_services`; each distinct
    tasklet_pb2.LocalService description is registered at most once per server.
    """

    def __init__(self):
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        # mktemp only generates a path name; the server binds the socket to it below.
        # NOTE(review): tempfile.mktemp is documented as deprecated because the name can be
        # taken by someone else before it is used -- consider a private mkdtemp() directory.
        self.address = "unix://{}".format(tempfile.mktemp(prefix="tasklet_server_"))
        self._server.add_insecure_port(self.address)
        self.channel = grpc.insecure_channel(self.address)
        # Serialized LocalService message -> ServiceRef already handed out for it.
        self._registered_services = {}
        self._started = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The server is left running when the `with` block ends; stop() has to be called
        # explicitly. NOTE(review): confirm this is intended and not a missing self.stop().
        pass

    def __hash__(self):
        return hash(self.address)

    def start(self):
        """
        Starts the server unless it is already running.
        """
        if not self._started:
            self._server.start()
        self._started = True

    def stop(self):
        """
        Closes the client channel and stops the server. Does nothing if the server is not running.
        """
        if self._started:
            self.channel.close()
            self._server.stop(None)
        self._started = False

    def _register_local_service(self, local_service):
        """
        Registers the implementation described by `local_service` on this server.

        :param tasklet_pb2.LocalService local_service: description with `register`, `impl`
            and `client` symbol paths
        :rtype: tasklet_pb2.ServiceRef: reference to the service on this server; the same
            object is returned when an identical description is registered again
        """
        # The serialized message is used as the identity of the description.
        fingerprint = local_service.SerializeToString()
        if fingerprint in self._registered_services:
            logger.debug(
                "Service %s already registered via %s [address %s]",
                local_service.impl,
                local_service.register,
                self.address,
            )
            return self._registered_services[fingerprint]

        logger.debug(
            "Registering service %s via %s [address %s]", local_service.impl, local_service.register, self.address
        )
        register = utils.import_symbol(local_service.register)
        impl = utils.import_symbol(local_service.impl)
        register(impl(), self._server)

        ref = tasklet_pb2.ServiceRef(address=self.address, client=local_service.client)
        self._registered_services[fingerprint] = ref
        return ref

    def register_services(self, entries):
        """
        Registers local services from the given entries on this server.

        :param Iterable entries: context entries, each packing a tasklet_pb2.LocalService

        :rtype: Dict: {entry name: ServiceInfo}
        :raises RuntimeError: if an entry does not pack a tasklet_pb2.LocalService
        """

        field_name_2_info = {}

        for entry in entries:  # type: tasklet_pb2.ContextEntry
            if not entry.any.Is(tasklet_pb2.LocalService.DESCRIPTOR):
                raise RuntimeError("Context entry {} is not a local service".format(entry.name))

            new_service = tasklet_pb2.LocalService()
            entry.any.Unpack(new_service)

            service_ref = self._register_local_service(new_service)
            field_name_2_info[entry.name] = ServiceInfo(service_ref, entry, new_service)
        return field_name_2_info

    def provide_client_context_for_service(self, service_descr, service_name_2_refs):
        """
        Passes references to the services requested by a registered service back to it.

        A client of type `service_descr.client` is created over this server's own channel.
        If it has a `GetContext` method, the returned context message is filled in (the
        `ref` of every field is unpacked from the matching entry) and sent back via `Inject`.

        :param service_descr: description whose `client` symbol path names the client class
        :param Dict service_name_2_refs: {message type name: list of context entries}

        :raises RuntimeError: if a requested field has no provided entry, or has more than one
        """
        client_cls = utils.import_symbol(service_descr.client)
        client_obj = client_cls(self.channel)
        if hasattr(client_obj, "GetContext"):
            client_context = client_obj.GetContext(empty_pb2.Empty())
            for field in client_context.DESCRIPTOR.fields:
                type_name = field.message_type.name
                if type_name not in service_name_2_refs:
                    raise RuntimeError("Requested context field {} has not been provided".format(field.name))
                provided = service_name_2_refs[type_name]
                if len(provided) > 1:
                    raise RuntimeError("Multiple requested services providing is deprecated: {}".format(type_name))
                requested = getattr(client_context, field.name)
                provided[0].any.Unpack(requested.ref)

            client_obj.Inject(client_context)


class LocalGrpcServerPool(six.with_metaclass(patterns.ThreadSafeSingletonMeta, object)):
    """
    Pool of LocalGrpcServer instances, created with patterns.ThreadSafeSingletonMeta as metaclass.

    Servers are added on demand so that every group of entries produced by
    `_split_ident_entries` gets a server of its own.
    """

    def __init__(self):
        self._servers = []
        # Server address -> names of the entries registered there by the latest setup_services call.
        self._address_2_names = {}

    def setup_services(self, ctx):
        """
        Registers the services of `ctx` on pooled servers and wires their contexts together.

        :param tasklet_pb2.Context ctx: context whose entries are set up on the pooled grpc servers

        :rtype: tasklet_pb2.Context: context where every registered entry is packed as the
            tasklet_pb2.ServiceRef it was registered under
        """

        groups = self._split_ident_entries(ctx.entries)

        # One server per group; grow the pool until there are enough.
        while len(self._servers) < len(groups):
            self._add_server()

        info_by_field = {}
        for group, server in six.moves.zip(groups, self._servers):
            registered = server.register_services(group)
            self._address_2_names[server.address] = [info.entry.name for info in registered.values()]
            info_by_field.update(registered)

        result_ctx = self._prepare_context(info_by_field)

        # Index the resulting entries by the service name they provide.
        entries_by_ctx_name = {}
        for entry in result_ctx.entries:
            entries_by_ctx_name.setdefault(entry.service_name_ctx, []).append(entry)

        # Let every server that holds an entry pass the requested references to that service.
        for same_name_entries in entries_by_ctx_name.values():
            for entry in same_name_entries:
                for server in self._servers:
                    if entry.name not in self._address_2_names[server.address]:
                        continue
                    server.provide_client_context_for_service(
                        info_by_field[entry.name].descr, entries_by_ctx_name
                    )

        return result_ctx

    @staticmethod
    def _prepare_context(field_name_2_info):
        """
        Builds a tasklet_pb2.Context with one entry per item of `field_name_2_info`.

        :param Dict field_name_2_info: {entry name: ServiceInfo}
        :rtype: tasklet_pb2.Context: each entry packs the ServiceRef of its ServiceInfo and
            keeps the service_name_ctx of the original entry
        """
        prepared = tasklet_pb2.Context()
        for field_name, info in field_name_2_info.items():
            slot = prepared.entries.add()
            slot.name = field_name
            slot.any.Pack(info.ref)
            slot.service_name_ctx = info.entry.service_name_ctx
        return prepared

    def _add_server(self):
        """
        Creates one more server, starts it and adds it to the pool.
        """
        fresh = LocalGrpcServer()
        fresh.start()
        self._servers.append(fresh)

    @staticmethod
    def _split_ident_entries(entries):
        """
        Splits entries into groups so that no group holds two entries with the same
        service_name_ctx; each group is meant for a separate grpc server.

        :param entries: context entries to split
        :rtype: list of lists: groups of entries; the first group may be empty

        NOTE(review): after an entry is placed the scan continues, so the entry is also added to
        every later group that does not have its service_name_ctx yet -- confirm this is intended.
        """
        groups = [[]]
        taken_names = [set()]

        for entry in entries:
            placed = False
            for members, names in zip(groups, taken_names):
                # An entry equal to one already in this group stops the scan; if it was not
                # placed before that, it goes into a new group below.
                if entry in members:
                    break
                if entry.service_name_ctx in names:
                    continue
                names.add(entry.service_name_ctx)
                members.append(entry)
                placed = True
            if not placed:
                groups.append([entry])
                taken_names.append({entry.service_name_ctx})

        return groups


def setup_python_error(error_msg):
    """
    Marks `error_msg` as failed and describes the exception currently being handled.

    For grpc.RpcError only `error` is set, from the error details. For any other exception
    the type path, the pickled value and the formatted traceback are stored in
    `python_error` and `is_python_error` is set. If that description cannot be built,
    `error` is set to repr() of the exception instead.

    Must be called from inside an `except` block: the exception is taken from sys.exc_info().

    :param error_msg: message with `success`, `error`, `is_python_error` and `python_error` fields
    """
    error_msg.success = False

    et, ev, tb = sys.exc_info()

    if isinstance(ev, grpc.RpcError):
        error_msg.error = ev.details()
        return

    try:
        error_msg.python_error.et = utils.symbol_path(et)
        error_msg.python_error.ev = pickle.dumps(ev)
        error_msg.python_error.tb = "".join(traceback.format_exception(et, ev, tb))

        error_msg.is_python_error = True
    except Exception:
        # pickle reports unpicklable values with TypeError or AttributeError as well as
        # PicklingError. This function runs while another error is being reported, so any
        # failure here falls back to the textual form instead of raising a second exception.
        error_msg.error = repr(ev)


def launch(job):
    """
    Runs a tasklet through dispatch.dispatch and returns the parsed result.

    The start and the duration of the execution are logged; the duration is logged
    even when dispatching raises.

    :param tasklet_pb2.JobInstance job:
    :rtype: tasklet_pb2.JobResult
    """
    tasklet_name = job.statement.name
    if job.id:
        with_id = " with id={}".format(job.id)
    else:
        with_id = ""

    logger.info("Executing '%s' tasklet%s", tasklet_name, with_id)
    started_at = time.time()
    try:
        serialized_result = dispatch.dispatch(job.SerializeToString())
    finally:
        elapsed = time.time() - started_at
        logger.info("'%s' tasklet%s was executed in %0.2fs", tasklet_name, with_id, elapsed)

    result = tasklet_pb2.JobResult()
    result.ParseFromString(serialized_result)
    return result


# Function is called from cpp code.
def execute_helper(impl_path, data):
    """
    Runs the tasklet implementation and always returns a serialized tasklet_pb2.JobResult.

    Exceptions escaping `_execute_helper` are not propagated: they are recorded in the
    result with a generic prefix in the `error` field.

    :param impl_path: symbol path of the implementation (text or bytes)
    :param data: serialized tasklet_pb2.JobInstance
    :rtype: bytes: serialized tasklet_pb2.JobResult
    """
    impl_path = six.ensure_text(impl_path)
    logger.debug("Execute %s with request size %s", impl_path, len(data))

    result = tasklet_pb2.JobResult()
    try:
        inner_result = _execute_helper(impl_path, data)
        result.CopyFrom(inner_result)
    except Exception:
        setup_python_error(result)
        result.error = "Something went really wrong, please contact tasklet@. {}".format(result.error)

    return result.SerializeToString()


# Function is called from cpp code.
def get_tasklet_name(impl_path, data):
    """
    Returns the name of the holder class of the implementation, as bytes.

    :param impl_path: symbol path of the implementation (text or bytes)
    :param data: not used
    """
    impl_symbol = six.ensure_text(impl_path)
    holder_cls = utils.import_symbol(impl_symbol).__holder_cls__
    return six.ensure_binary(holder_cls.name)


# Function is called from cpp code.
def get_init_description(impl_path, data):
    """
    Completes a job statement with the default input and requirements of the implementation.

    :param impl_path: symbol path of the implementation (text or bytes)
    :param data: serialized tasklet_pb2.JobStatement
    :rtype: bytes: the serialized statement with `input` repacked and, when the
        implementation set any, `requirements` filled in
    """
    statement = tasklet_pb2.JobStatement()
    statement.ParseFromString(data)

    impl_cls = utils.import_symbol(six.ensure_text(impl_path))
    impl = impl_cls(statement)

    # The supplied input is unpacked into a scratch message and merged into impl.input only
    # after the defaults have been set there.
    supplied_input = impl.input.__class__()
    statement.input.Unpack(supplied_input)
    impl.setup_default_input(impl.input)
    impl.input.MergeFrom(supplied_input)

    requirements = tasklet_pb2.Requirements()
    impl.setup_requirements(requirements, impl.input)

    statement.input.Pack(impl.input)

    # A message that serializes to nothing has no requirements to copy.
    serialized_requirements = requirements.SerializeToString()
    if serialized_requirements:
        statement.requirements.CopyFrom(requirements)

    return statement.SerializeToString()


def _execute_helper(impl_path, data):
    """
    Parses a job, sets up its local services, runs the implementation and collects the result.

    Errors raised by `impl.run()` or while packing the output are recorded in the
    returned message; errors raised before that propagate to the caller.

    :param impl_path: symbol path of the implementation
    :param data: serialized tasklet_pb2.JobInstance
    :rtype: tasklet_pb2.JobResult
    """
    job = tasklet_pb2.JobInstance()
    job.ParseFromString(data)

    impl = utils.import_symbol(impl_path)(job)
    job.statement.input.Unpack(impl.input)

    result = tasklet_pb2.JobResult()

    resolved_ctx = LocalGrpcServerPool().setup_services(job.statement.ctx)

    if hasattr(impl, "ctx"):
        impl.ctx = cons(impl.ctx, resolved_ctx, job.statement.ctx)

    try:
        impl.run()
    except Exception:
        setup_python_error(result)
    else:
        result.success = True

    try:
        result.output.Pack(impl.output)
    except Exception:
        # A packing failure is reported only when the run succeeded, so an error already
        # recorded for run() is not overwritten.
        if result.success:
            setup_python_error(result)

    return result


def cons(ctx_decl, context, orig_ctx):
    """
    Builds a plain object with one attribute per context entry declared in `ctx_decl`.

    Entries packing a tasklet_pb2.ServiceRef become client objects, other entries are
    used as they are. When the message type of the declared field has a `py_adapter`
    option, the value is passed through that adapter. Entries whose names are not
    fields of `ctx_decl` are skipped.

    :param ctx_decl: message whose descriptor lists the declared context fields
    :param context: context whose `entries` are turned into attributes
    :param orig_ctx: stored unchanged as the `ctx_msg` attribute
    :rtype: object with `ctx_msg` and one attribute per accepted entry
    """
    class Obj(object):
        pass

    holder = Obj()
    holder.ctx_msg = orig_ctx

    declared_fields = dict(ctx_decl.DESCRIPTOR.fields_by_name)

    for entry in context.entries:
        field = declared_fields.get(entry.name)
        if field is None:
            continue

        if entry.any.Is(tasklet_pb2.ServiceRef.DESCRIPTOR):
            ref = tasklet_pb2.ServiceRef()
            entry.any.Unpack(ref)
            value = cons_service_ref(ref)
        else:
            value = entry

        message_type = field.message_type
        if message_type is not None:
            adapter_path = message_type.GetOptions().Extensions[tasklet_pb2.py_adapter]
            if adapter_path:
                value = utils.import_symbol(adapter_path)(value)

        setattr(holder, entry.name, value)

    return holder


def cons_service_ref(service_ref):
    """
    Creates a client for the service that `service_ref` points to.

    :param service_ref: reference with the `address` to connect to and the `client` symbol path
    :return: instance of the client class constructed over a new insecure channel

    The channel is handed to the client and is not closed here.
    """
    new_channel = grpc.insecure_channel(service_ref.address)
    client_cls = utils.import_symbol(service_ref.client)
    return client_cls(new_channel)


def inject(request, into):
    """
    Sets one attribute on `into` for every field of the `request` message.

    A field value is replaced by a client built from its `ref` when that works; otherwise
    the value itself is used. When the message type of the field has a `py_adapter`
    option, the value is passed through that adapter.

    :param request: message whose fields are injected
    :param into: object receiving the attributes
    """
    for field in request.DESCRIPTOR.fields:
        raw_value = getattr(request, field.name)

        # NOTE(review): any AttributeError raised while building the client, not only a
        # missing `ref`, falls back to the raw value (the original marks this with "XXX: fix").
        try:
            value = cons_service_ref(raw_value.ref)
        except AttributeError:
            value = raw_value

        message_type = field.message_type
        if message_type is not None:
            adapter_path = message_type.GetOptions().Extensions[tasklet_pb2.py_adapter]
            if adapter_path:
                value = utils.import_symbol(adapter_path)(value)

        setattr(into, field.name, value)
