import json
import logging
import os
import sys
import tempfile

from google.protobuf import empty_pb2
from google.protobuf import json_format

from tasklet.api import sched_pb2
from tasklet.api import sched_pb2_grpc
from tasklet.api import spy_pb2

import yt.wrapper as yt_wrapper

from tasklet import runtime
from tasklet.domain import oauth
from tasklet.runtime import dispatch, utils

logger = logging.getLogger(__name__)


# Set at import time, before YT_CLIENT below is constructed. Presumably this
# lets yt.wrapper issue HTTP requests to YT from inside a YT job (the tasklet
# itself runs as a YT job, see YtScheduler.Schedule) — confirm.
os.environ["YT_ALLOW_HTTP_REQUESTS_TO_YT_FROM_JOB"] = "1"


# Module-wide YT client shared by the scheduler.
# Proxy resolution order: YT_CLUSTER_NAME, then YT_PROXY, then "hahn".
# NOTE: oauth.get_token() is called at import time, so importing this module
# requires the token to be obtainable.
YT_CLIENT = yt_wrapper.YtClient(
    proxy=os.environ.get("YT_CLUSTER_NAME", os.environ.get("YT_PROXY", "hahn")),
    token=oauth.get_token(),
)


def upload_job_instance_spec(yt_client, instance):
    """Serialize a job instance to JSON and upload it to YT.

    The protobuf message is converted with ``json_format.MessageToDict``,
    written into a local temporary file, and that file is handed to
    ``yt_client.smart_upload_file``. The temporary file is removed once the
    upload call returns.

    :type yt_client: yt.wrapper.YtClient
    :type instance: tasklet_pb2.JobInstance
    :return: the value returned by ``smart_upload_file`` (used by the caller
        as the YT path of the uploaded spec).
    """
    payload = json_format.MessageToDict(instance)
    with tempfile.NamedTemporaryFile(mode="w") as spec_file:
        json.dump(payload, spec_file)
        # The upload reads the file by name, so the buffered data must hit disk first.
        spec_file.flush()
        return yt_client.smart_upload_file(spec_file.name)


class YtScheduler(sched_pb2_grpc.SchedulerServicer):
    """Tasklet scheduler that runs every tasklet as a single-job YT vanilla operation.

    Lifecycle: :meth:`Instance` registers a tasklet and assigns it an id,
    :meth:`Schedule` uploads its spec and starts the YT operation, and
    :meth:`GetStatus` / :meth:`WaitFor` poll or block on that operation.
    Lifecycle events are reported through ``self.spy.whisper``.

    NOTE(review): ``self.spy`` is not set in ``__init__``; presumably it is
    attached by ``runtime.inject`` in :meth:`Inject` — confirm that Inject is
    always called before the other RPCs.
    """

    def __init__(self):
        # Registered (not yet scheduled) JobInstance requests, keyed by tasklet id.
        self._job_descriptions = {}
        # Tasklet id -> YT operation id (as str) for launched tasklets.
        self._tasklet_to_job = {}
        # Generated lazily on the first Instance() call and shared by all tasklets.
        self.run_id = None

    def Inject(self, request, context):
        """Pass the request and this scheduler to ``runtime.inject``."""
        runtime.inject(request, self)

        return empty_pb2.Empty()

    def Instance(self, request, context):
        """Register a tasklet instance and return its freshly generated id.

        Fills ``request.run_id`` and ``request.id``, merges the tasklet's init
        description into ``request.statement``, stores the request for a later
        :meth:`Schedule` call and reports a SCHEDULED event.
        """
        logger.info('Instance')

        if self.run_id is None:
            self.run_id = utils.generate_run_id()

        request.run_id = self.run_id
        request.id = utils.generate_tasklet_id()

        request.statement.MergeFromString(dispatch.get_init_description(request.statement.SerializeToString()))

        self._job_descriptions[request.id] = request

        self.spy.whisper(
            state=spy_pb2.Event.SCHEDULED,
            id=request.id,
            name=request.statement.name,
            parent=request.parent_id,
            run_id=self.run_id,
        )

        return sched_pb2.TaskletId(id=request.id)

    def Schedule(self, request, context):
        """Launch a registered tasklet as an asynchronous YT vanilla operation.

        The instance stored by :meth:`Instance` is consumed (a second Schedule
        for the same id fails). Its spec is uploaded to YT and a one-job
        vanilla operation running ``./tasklet spec.json`` is started without
        waiting for it.

        :param request: ``TaskletId`` previously returned by :meth:`Instance`.
        :return: ``request`` on success, or an empty ``TaskletId`` when no
            registered instance matches ``request.id``.
        """
        if request.id not in self._job_descriptions:
            logger.error('No job description with ID "%s" has been found', request.id)
            return sched_pb2.TaskletId()
        instance = self._job_descriptions.pop(request.id)

        logger.info('Schedule')
        job_meta_yt_path = upload_job_instance_spec(YT_CLIENT, instance)

        # Reuse an already uploaded binary when TASKLET_BINARY is set (it is also
        # forwarded to the job below); otherwise upload the current executable.
        tasklet_binary = os.environ.get("TASKLET_BINARY")
        if not tasklet_binary:
            tasklet_binary = YT_CLIENT.smart_upload_file(sys.executable)

        op_spec = {
            # The OAuth token reaches the job through the secure vault.
            "secure_vault": {
                oauth.VAULT_NAME: oauth.get_token(),
            },
            "fail_on_job_restart": True,
        }

        task_spec = {
            "job_count": 1,
            "command": "./tasklet spec.json",
            "environment": {
                "TASKLET_BINARY": tasklet_binary,
                "Y_PYTHON_ENTRY_POINT": "tasklet.domain.yt.main:main",
                "YT_CLUSTER_NAME": os.environ.get("YT_CLUSTER_NAME", os.environ.get("YT_PROXY", "hahn")),
            },
            "file_paths": [
                yt_wrapper.FilePath(tasklet_binary, file_name="tasklet", executable=True),
                yt_wrapper.FilePath(job_meta_yt_path, file_name="spec.json")
            ]
        }

        requirements = instance.statement.requirements
        # A zero requirement is treated as "not requested" and leaves the key
        # unset. Sizes are scaled by 2**20 (presumably MiB -> bytes).
        if requirements.cpu != 0:
            task_spec["cpu_limit"] = requirements.cpu

        if requirements.ram != 0:
            task_spec["memory_limit"] = requirements.ram << 20

        if requirements.tmpfs != 0:
            tmpfs_path = "tmpfs"
            task_spec.update({"tmpfs_path": tmpfs_path, "tmpfs_size": requirements.tmpfs << 20})
            task_spec["environment"]["TMPFS_PATH"] = tmpfs_path

        if requirements.disk != 0:
            # NOTE(review): presumably disk limits are only honored on porto
            # nodes, hence the tag filter — confirm.
            op_spec["scheduling_tag_filter"] = "porto"
            task_spec["disk_space_limit"] = requirements.disk << 20

        vanilla_spec = yt_wrapper.VanillaSpecBuilder() \
            .spec(op_spec) \
            .task("task", task_spec)

        op = YT_CLIENT.run_operation(vanilla_spec, sync=False)
        logger.debug("Operation %s started", op.id)
        self._tasklet_to_job[request.id] = str(op.id)

        self.spy.whisper(state=spy_pb2.Event.LAUNCHED, id=request.id, run_id=self.run_id)

        return request

    def GetStatus(self, request_iterator, context):
        """Yield one ``JobStatus`` per requested tasklet id.

        Operations that are still running yield ``ready=False``.

        Finished operations yield ``ready=True`` with ``result`` filled in:
        - On failure, ``success=False`` with a generic error and a FAILURE event.
        - On success, the result records parsed from the single job's stderr,
          ``success=True`` and a SUCCESS event.

        :raises KeyError: for a tasklet id that was never scheduled.
        :raises RuntimeError: when a successful operation does not expose the
            stderr of exactly one job.
        """
        for tasklet_id in request_iterator:
            job_id = self._tasklet_to_job[tasklet_id.id]

            state = YT_CLIENT.get_operation_state(job_id)

            job_status = sched_pb2.JobStatus()
            if not state.is_finished():
                job_status.ready = False
                yield job_status
                continue

            job_status.ready = True

            if state.is_unsuccessfully_finished():
                job_status.result.success = False
                job_status.result.error = "YT task failed"

                self.spy.whisper(state=spy_pb2.Event.FAILURE, id=tasklet_id.id, run_id=self.run_id)
            else:
                op = yt_wrapper.Operation(job_id, client=YT_CLIENT)
                stderrs = op.get_jobs_with_error_or_stderr(False)
                # Explicit check rather than `assert`, which disappears under `python -O`.
                if len(stderrs) != 1:
                    raise RuntimeError(
                        "Expected stderr of exactly one job for operation {}, got {}".format(job_id, len(stderrs))
                    )

                self._parse_result(stderrs[0]["stderr"], job_status.result)
                job_status.result.success = True

                self.spy.whisper(state=spy_pb2.Event.SUCCESS, id=tasklet_id.id, run_id=self.run_id)

            yield job_status

    @staticmethod
    def _parse_result(stderr, result):
        """Parse every ``{"type": "result", "value": <json>}`` line of *stderr* into *result*.

        Blank lines, lines that are not valid JSON, and JSON values that are
        not objects (e.g. stray log output on the job's stderr) are skipped.
        Non-JSON lines are logged with a warning instead of aborting the whole
        status stream.
        """
        for line in stderr.splitlines():
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except ValueError:
                logger.warning("Skipping non-JSON stderr line: %r", line)
                continue
            if isinstance(record, dict) and record.get("type") == "result":
                json_format.Parse(record["value"], result)

    def WaitFor(self, request_iterator, context):
        """Block until the YT operations of all requested tasklets have completed.

        A FINISHED event is reported for each tasklet only after waiting is
        over. If waiting raises (e.g. a failed operation), the exception
        propagates and no FINISHED event is reported.

        :raises KeyError: for a tasklet id that was never scheduled.
        """
        waited_ids = []
        with yt_wrapper.OperationsTracker() as tracker:
            for tasklet_id in request_iterator:
                job_id = self._tasklet_to_job[tasklet_id.id]

                logger.debug("Waiting for operation %s", job_id)
                tracker.add_by_id(job_id, client=YT_CLIENT)
                waited_ids.append(tasklet_id.id)
        # OperationsTracker waits for the added operations when the `with` block
        # exits. Whispering inside the loop would report FINISHED as soon as an
        # operation was merely registered with the tracker.
        for waited_id in waited_ids:
            self.spy.whisper(state=spy_pb2.Event.FINISHED, id=waited_id, run_id=self.run_id)

        return empty_pb2.Empty()

    def GetContext(self, request, context):
        """Return an empty ``SchedulerContext``; this scheduler exposes no context."""
        return sched_pb2.SchedulerContext()
