import base64
import enum
import logging
import time

from google.protobuf import empty_pb2
from google.protobuf import json_format

from library.python import func

from sandbox import sdk2
from sandbox.common.types import misc as ctm
from sandbox.common.types import task as ctt

from tasklet import runtime
from tasklet.api import sched_pb2
from tasklet.api import sched_pb2_grpc
from tasklet.api import spy_pb2
from tasklet.domain import oauth
from tasklet.runtime import dispatch
from tasklet.runtime import utils

from . import task as sb_task
from . import utils as sb_utils

# Module-level logger used by the schedulers below.
logger = logging.getLogger(__name__)


@func.memoize()
def sandbox_task_cls(tasklet_id):
    """
    Resolve the Sandbox task class that implements a tasklet.

    The implementation class is imported from the path given by
    ``dispatch.impl_class_path``; its ``__holder_cls__.sbtask_cls`` is returned.
    Results are cached via ``func.memoize``.

    :param tasklet_id: tasklet identifier (callers in this module pass ``job.statement.name``).
    :rtype: tasklet.domain.sandbox.BaseTasklet
    """
    impl_path = dispatch.impl_class_path(tasklet_id)
    holder_cls = utils.import_symbol(impl_path).__holder_cls__
    return holder_cls.sbtask_cls


class ProtoFormat(enum.Enum):
    """Encoding used for protobuf messages passed through Sandbox task parameters."""

    # Serialized protobuf bytes, base64-encoded.
    BINARY = "binary"
    # Text produced by ``json_format.MessageToJson``.
    JSON = "json"


class SandboxScheduler(sched_pb2_grpc.SchedulerServicer):
    """
    Tasklet scheduler that runs every tasklet job as a separate Sandbox task.

    Instantiating ``SandboxScheduler`` itself returns the concrete implementation for the
    current environment: ``SandboxOutsideScheduler`` when ``sdk2.Task.current`` is ``None``
    and ``SandboxInsideScheduler`` otherwise.

    :param ProtoFormat proto_format: encoding of the job request/response passed through
        task parameters when ``use_glycine`` is set.
    :param bool use_glycine: run jobs as generic ``GLYCINE_2`` tasks that receive the whole
        job message; otherwise run the tasklet's own Sandbox task class.

    NOTE(review): ``self.spy`` is used below but never assigned in this class; presumably it
    is provided via ``runtime.inject`` (see ``Inject``) — verify.
    """

    def __new__(cls, *args, **kwargs):
        if cls is SandboxScheduler:
            # Pick the concrete subclass but let ``type.__call__`` invoke its ``__init__``.
            # Constructing the subclass here would run ``__init__`` twice: the returned object
            # is still a ``SandboxScheduler`` instance, so Python initializes it again.
            cls = SandboxOutsideScheduler if sdk2.Task.current is None else SandboxInsideScheduler
        return super(SandboxScheduler, cls).__new__(cls)

    def __init__(self, proto_format=ProtoFormat.JSON, use_glycine=True):
        self._proto_format = proto_format
        self._use_glycine = use_glycine
        # Tasklet id -> Sandbox task id, and the reverse mapping.
        self._tasklet_to_job = {}
        self._job_to_tasklet = {}
        # Shared by all jobs of this scheduler; generated on the first ``Instance`` call.
        self._run_id = None

    @func.lazy_property
    def _client(self):
        return oauth.sandbox_client()

    @func.lazy_property
    def _owner(self):
        return sb_utils.sandbox_owner()

    @func.lazy_property
    def _tasks_resource_id(self):
        return sb_utils.tasks_binary_resource(self._owner, not self._use_glycine)

    def _task_custom_parameters(self, job):
        """
        Build the custom task parameters that carry the job to the Sandbox task.

        With glycine the whole job message is passed, encoded per ``proto_format``;
        otherwise the tasklet input is unpacked and passed as a dict with tasklet metadata.

        :param tasklet_pb2.JobInstance job:
        :rtype: dict
        """
        if self._use_glycine:
            if self._proto_format == ProtoFormat.JSON:
                request = json_format.MessageToJson(job)
            else:
                # ``b64encode`` returns bytes on Python 3, which cannot be JSON-encoded as a
                # custom field value; the base64 alphabet is ASCII, so decode it to text.
                request = base64.b64encode(job.SerializeToString()).decode("ascii")
            return {
                "proto_format": self._proto_format.value,
                "request": request,
            }

        sb_task_cls = sandbox_task_cls(job.statement.name)
        params_message = sb_task_cls.__holder_cls__.Input()
        job.statement.input.Unpack(params_message)

        tasklet_input_object = json_format.MessageToDict(params_message, preserving_proto_field_name=True)
        return {
            "__tasklet_input__": tasklet_input_object,
            "__tasklet_id__": job.id,
            "__tasklet_name__": job.statement.name,
            "__run_id__": job.run_id,
        }

    def _task_requirements(self, job):
        """
        Translate job requirements into Sandbox task requirements.

        Zero-valued ``tmpfs``/``ram``/``disk``/``cpu`` are omitted; ``tmpfs``, ``ram`` and
        ``disk`` are shifted left by 20 bits (presumably MiB -> bytes). Every truthy field of
        ``requirements.sandbox`` is copied under its own name.

        :param tasklet_pb2.JobInstance job:
        :rtype: dict
        """
        reqs = job.statement.requirements
        requirements = {
            "tasks_resource": self._tasks_resource_id,
            "caches": {},  # no vcs cache will be available
        }

        if reqs.tmpfs != 0:
            requirements["ramdrive"] = {
                "type": "tmpfs",
                "size": reqs.tmpfs << 20,
            }

        if reqs.ram != 0:
            requirements["ram"] = reqs.ram << 20

        if reqs.disk != 0:
            requirements["disk_space"] = reqs.disk << 20

        if reqs.cpu != 0:
            requirements["cores"] = reqs.cpu

        for field in reqs.sandbox.DESCRIPTOR.fields:
            value = getattr(reqs.sandbox, field.name)
            if value:
                requirements[field.name] = value
        return requirements

    def _task_creation_parameters(self, job):
        """
        Build the Sandbox task creation request for a job.

        :param tasklet_pb2.JobInstance job:
        :rtype: dict
        """
        params = {
            "description": job.statement.name,
            "tags": ["TASKLET", job.statement.name],
            "requirements": self._task_requirements(job),
            "owner": self._owner,
            "priority": {"class": "SERVICE", "subclass": "HIGH"},
        }

        if job.statement.requirements.ttl != 0:
            params["kill_timeout"] = job.statement.requirements.ttl

        if self._use_glycine:
            params["type"] = "GLYCINE_2"
        else:
            params["type"] = sandbox_task_cls(job.statement.name).name

        custom_params = self._task_custom_parameters(job)
        params["custom_fields"] = [{"name": name, "value": value} for name, value in custom_params.items()]
        return params

    def Instance(self, request, context):
        """
        Create (but do not start) a Sandbox task for the job and register it.

        The job receives a fresh tasklet id and this scheduler's run id.

        :return: ``sched_pb2.TaskletId`` with the generated tasklet id.
        """
        request.statement.MergeFromString(dispatch.get_init_description(request.statement.SerializeToString()))
        request.id = utils.generate_tasklet_id()

        if self._run_id is None:
            self._run_id = utils.generate_run_id()

        request.run_id = self._run_id

        params = self._task_creation_parameters(request)
        task_info = self._client.task(params)

        task_id = task_info["id"]
        self._tasklet_to_job[request.id] = task_id
        self._job_to_tasklet[task_id] = request.id
        logger.info("Sandbox task created for tasklet '%s': %s", request.id, sb_utils.task_link(task_id))

        self.spy.whisper(
            state=spy_pb2.Event.SCHEDULED,
            id=request.id,
            name=request.statement.name,
            parent=request.parent_id,
            run_id=self._run_id,
        )

        return sched_pb2.TaskletId(id=request.id)

    @staticmethod
    def _is_finished(task_status):
        return task_status in ctt.Status.Group.FINISH + ctt.Status.Group.BREAK

    def _parse_result(self, output_parameters, result):
        """Fill ``result`` from a task's output parameters, if the task reported one."""
        if self._use_glycine:
            response = output_parameters.get("response")
            if response is None:
                return
            if self._proto_format == ProtoFormat.JSON:
                json_format.Parse(response, result)
            else:
                result.ParseFromString(base64.b64decode(response))
        else:
            tasklet_result = output_parameters.get(sb_task.BaseTasklet.Parameters.tasklet_result.name)
            if tasklet_result is not None:
                json_format.Parse(tasklet_result, result)

    def GetStatus(self, request_iterator, context):
        """
        Yield one ``sched_pb2.JobStatus`` per requested tasklet, in request order.

        A SUCCESS/FAILURE spy event is emitted only for tasks that are finished.
        """
        ids = [self._tasklet_to_job[request.id] for request in request_iterator]
        if not ids:
            return

        tasks = self._client.task.read(
            id=ids,
            # "id" is required to match returned tasks back to the requests.
            fields=["id", "status", "output_parameters"],
            limit=len(ids),
        )
        # Do not rely on the API returning items in the order of ``ids``: the statuses carry
        # no id, so callers can only correlate them positionally.
        tasks_by_id = {task["id"]: task for task in tasks["items"]}

        for task_id in ids:
            job_status = sched_pb2.JobStatus()
            task = tasks_by_id.get(task_id)
            if task is None:
                logger.warning("Sandbox task %s was not returned by the API", task_id)
                yield job_status
                continue

            job_status.ready = self._is_finished(task["status"])
            self._parse_result(task.get("output_parameters") or {}, job_status.result)
            if job_status.ready:
                # Report the outcome only for finished tasks: before that ``result.success`` is
                # merely unset, which would be reported as a failure.
                self.spy.whisper(
                    state=spy_pb2.Event.SUCCESS if job_status.result.success else spy_pb2.Event.FAILURE,
                    id=self._job_to_tasklet[task_id],
                    run_id=self._run_id,
                )
            yield job_status

    def Inject(self, request, context):
        runtime.inject(request, self)

        return empty_pb2.Empty()

    def GetContext(self, request, context):
        return sched_pb2.SchedulerContext()


class SandboxTaskletScheduler(SandboxScheduler):
    """
    Factory for a ``SandboxScheduler`` that runs tasklets through their own Sandbox task
    classes instead of ``GLYCINE_2``.

    Any ``use_glycine`` argument is overridden with ``False``. The returned object is the
    concrete scheduler chosen by ``SandboxScheduler``, not an instance of this class.
    """

    def __new__(cls, **kwargs):
        options = dict(kwargs, use_glycine=False)
        return SandboxScheduler(**options)


class SandboxOutsideScheduler(SandboxScheduler):
    """
    Scheduler used when no Sandbox task is currently executing.

    Tasks are started through the Sandbox API client and awaited by polling their status.
    With glycine enabled, the OAuth token is created or updated on construction.
    """

    # Seconds to sleep between status polls in ``WaitFor``.
    _POLL_INTERVAL = 3

    def __init__(self, *args, **kwargs):
        super(SandboxOutsideScheduler, self).__init__(*args, **kwargs)
        if self._use_glycine:
            oauth.create_or_update_oauth_token()

    def Schedule(self, request, context):
        """
        Start the Sandbox task created for the tasklet.

        :raises Exception: if Sandbox reports an error while starting the task.
        """
        task_id = self._tasklet_to_job[request.id]
        result = self._client.batch.tasks.start.update(task_id)[0]

        status = result["status"]
        if status == ctm.BatchResultStatus.ERROR:
            # The status is always ERROR here; Sandbox's message is what explains the failure.
            raise Exception("Failed to start task {}: {}".format(task_id, result.get("message")))
        if status == ctm.BatchResultStatus.WARNING:
            logger.warning("Warning while starting task: %s", result["message"])

        logger.debug("Task %s started", task_id)
        return request

    def _wait_task(self, tasklet_id):
        """
        Poll the tasklet's Sandbox task until it reaches a finished/break status.

        :raises Exception: if the final status is not ``SUCCESS``.
        """
        task_id = self._tasklet_to_job[tasklet_id]
        logger.debug("Waiting for task %s in tasklet %s", task_id, tasklet_id)

        while True:
            task_status = self._client.task[task_id].read()["status"]
            if self._is_finished(task_status):
                break
            time.sleep(self._POLL_INTERVAL)

        if task_status != ctt.Status.SUCCESS:
            raise Exception("The status of task {} in tasklet {} is {}".format(
                task_id,
                tasklet_id,
                task_status,
            ))

    def WaitFor(self, request_iterator, context):
        """Block until every requested tasklet's task finishes successfully, one by one."""
        for request in request_iterator:
            self._wait_task(request.id)

        return empty_pb2.Empty()


class SandboxInsideScheduler(SandboxScheduler):
    """
    Scheduler used from within a running Sandbox task (``sdk2.Task.current``).

    The API client, owner and tasks resource come from the current task. Created tasks are
    marked as its children and awaited with ``sdk2.WaitTask``. The run id is taken from the
    current task's ``__run_id__`` parameter.
    """

    def __init__(self, *args, **kwargs):
        super(SandboxInsideScheduler, self).__init__(*args, **kwargs)
        self._run_id = sdk2.Task.current.Parameters.__run_id__

    @func.lazy_property
    def _client(self):
        return sdk2.Task.current.server

    @func.lazy_property
    def _owner(self):
        return sdk2.Task.current.owner

    @func.lazy_property
    def _tasks_resource_id(self):
        return sdk2.Task.current.Requirements.tasks_resource.id

    def _task_custom_parameters(self, job):
        """Base custom parameters, plus the current task's tasklet secret (non-glycine only)."""
        custom_parameters = super(SandboxInsideScheduler, self)._task_custom_parameters(job)
        if not self._use_glycine and sdk2.Task.current.Parameters.__tasklet_secret__:
            custom_parameters["__tasklet_secret__"] = str(sdk2.Task.current.Parameters.__tasklet_secret__)
        return custom_parameters

    def _task_creation_parameters(self, job):
        params = super(SandboxInsideScheduler, self)._task_creation_parameters(job)
        params["children"] = True  # Mark current task as a parent for created one.
        return params

    def Schedule(self, request, context):
        """Enqueue the tasklet's Sandbox task and emit a LAUNCHED spy event for it."""
        task_id = self._tasklet_to_job[request.id]
        for child in sdk2.Task.current.find(id=task_id):
            child.enqueue()
            self.spy.whisper(state=spy_pb2.Event.LAUNCHED, id=request.id, run_id=self._run_id)

        return request

    def WaitFor(self, request_iterator, context):
        """Wait on the requested tasklets' tasks and emit a FINISHED spy event for each."""
        task = sdk2.Task.current
        # Materialize the requests: they are needed twice, and a one-shot iterator would be
        # empty on the second pass, silently skipping the FINISHED events.
        requests = list(request_iterator)
        refs = [self._tasklet_to_job[request.id] for request in requests]
        # NOTE(review): ``wait_all=False`` presumably wakes on the first finished task, while
        # FINISHED is reported for all requests below — confirm this is intended.
        sdk2.WaitTask(refs, ctt.Status.Group.FINISH + ctt.Status.Group.BREAK, wait_all=False)(task)
        for request in requests:
            self.spy.whisper(state=spy_pb2.Event.FINISHED, id=request.id, run_id=self._run_id)
        return empty_pb2.Empty()
