import logging
import threading

from google.protobuf import empty_pb2

from tasklet.api import sched_pb2
from tasklet.api import sched_pb2_grpc
from tasklet.api import spy_pb2

from tasklet import runtime
from tasklet.runtime import dispatch
from tasklet.runtime import utils


logger = logging.getLogger(__name__)


class LocalScheduler(sched_pb2_grpc.SchedulerServicer):
    """In-process implementation of the ``Scheduler`` gRPC service.

    Tasklets are registered with :meth:`Instance`, started on a dedicated
    ``threading.Thread`` by :meth:`Schedule`, and observed through
    :meth:`GetStatus` / :meth:`WaitFor`. Lifecycle events (SCHEDULED,
    LAUNCHED, SUCCESS / FAILURE) are reported via ``self.spy.whisper``.

    NOTE(review): ``self.spy`` starts as ``None`` and ``self.executor`` is
    never assigned in this class; both are presumably populated by
    ``runtime.inject`` (called from :meth:`Inject`) before any tasklet is
    instanced or scheduled — confirm.
    """

    def __init__(self):
        # Tasklet id -> instanced job request still waiting to be scheduled.
        self._job_descriptions = {}
        # Tasklet id -> worker thread. Each thread carries a ``status``
        # attribute (a ``sched_pb2.JobStatus``) that its worker fills in.
        self._threads = {}
        # Single run id shared by every tasklet of this scheduler; created lazily.
        self.run_id = None
        self.spy = None

    def Inject(self, request, context):
        """Pass ``request`` together with this scheduler to ``runtime.inject``."""
        runtime.inject(request, self)

        return empty_pb2.Empty()

    def Instance(self, request, context):
        """Register a new tasklet and return its freshly generated id.

        Stamps ``request`` with the scheduler's run id and a new tasklet id,
        merges the init description from ``dispatch`` into its statement,
        stores it for a later :meth:`Schedule` call and emits a SCHEDULED
        event.

        Returns:
            ``sched_pb2.TaskletId`` carrying the new tasklet id.
        """
        # NOTE(review): check-then-set is not atomic; if the gRPC server
        # dispatches Instance calls concurrently, two run ids could be
        # generated for the first tasklets — confirm whether that matters.
        if self.run_id is None:
            self.run_id = utils.generate_run_id()

        request.run_id = self.run_id
        request.id = utils.generate_tasklet_id()

        init_description = dispatch.get_init_description(request.statement.SerializeToString())
        request.statement.MergeFromString(init_description)

        self._job_descriptions[request.id] = request

        holder = dispatch.get_holder(request.statement.name)
        metadata = utils.message_repr(holder.Input(), request.statement.input)

        self.spy.whisper(
            state=spy_pb2.Event.SCHEDULED,
            id=request.id,
            run_id=request.run_id,
            name=request.statement.name,
            parent=request.parent_id,
            metadata=metadata,
        )

        return sched_pb2.TaskletId(id=request.id)

    def GetContext(self, request, context):
        """Return an empty ``SchedulerContext``."""
        return sched_pb2.SchedulerContext()

    def Schedule(self, request, context):
        """Start executing a previously instanced tasklet on a new thread.

        The job description is consumed (removed) so a tasklet can only be
        scheduled once. The worker thread stores the execution result in its
        ``status`` and reports LAUNCHED and then SUCCESS or FAILURE.

        Returns:
            ``request`` on success, or an empty ``sched_pb2.TaskletId`` when
            no job description with ``request.id`` exists.
        """
        # Single atomic pop instead of check-then-pop, so two concurrent
        # Schedule calls for the same id cannot both pass the check.
        job = self._job_descriptions.pop(request.id, None)
        if job is None:
            logger.error('No job description with ID "%s" has been found', request.id)
            return sched_pb2.TaskletId()

        status = sched_pb2.JobStatus()

        def run():
            self.spy.whisper(state=spy_pb2.Event.LAUNCHED, id=job.id, run_id=self.run_id)

            try:
                res = self.executor.execute(job)
            except Exception:
                # Without this, the thread would die with status.ready still
                # False and pollers would wait forever for a result.
                logger.exception('Tasklet "%s" failed with an unhandled exception', job.id)
                status.result.success = False
                status.ready = True
                self.spy.whisper(state=spy_pb2.Event.FAILURE, id=job.id, run_id=self.run_id)
                return

            # Fill the result before flipping ``ready`` so readers that see
            # ready=True also see the result.
            status.result.CopyFrom(res)
            status.ready = True

            holder = dispatch.get_holder(job.statement.name)
            metadata = utils.message_repr(holder.Output(), status.result.output)

            self.spy.whisper(
                state=spy_pb2.Event.SUCCESS if res.success else spy_pb2.Event.FAILURE,
                id=job.id,
                run_id=self.run_id,
                metadata=metadata,
            )

        worker = threading.Thread(target=run)
        worker.status = status
        # Register before starting so GetStatus/WaitFor never miss a live job.
        self._threads[request.id] = worker
        worker.start()

        return request

    def GetStatus(self, request_iterator, context):
        """Yield one ``JobStatus`` per requested tasklet id.

        Unknown ids are logged and answered with ``ready=False``.
        """
        for request in request_iterator:
            status = sched_pb2.JobStatus()
            worker = self._threads.get(request.id)
            if worker is None:
                logger.error('No running job with the ID "%s" has been found', request.id)
                status.ready = False
            else:
                status.CopyFrom(worker.status)
            yield status

    def WaitFor(self, request_iterator, context):
        """Block until every requested tasklet's worker thread has finished.

        Unknown ids are logged and skipped.
        """
        for request in request_iterator:
            worker = self._threads.get(request.id)
            if worker is None:
                logger.error('No running job with the ID "%s" has been found', request.id)
            else:
                worker.join()
        return empty_pb2.Empty()
