import tempfile
from pathlib import Path

import faulthandler

import load.projects.cloud.loadtesting.server.admin.proto.profiling_pb2 as profiling_pb2
import load.projects.cloud.loadtesting.server.admin.proto.profiling_pb2_grpc as profiling_pb2_grpc
from load.projects.cloud.loadtesting import events
from load.projects.cloud.loadtesting.server.admin.implementations.profiling import EnableProfiling


class Profiling(profiling_pb2_grpc.ProfilingServicer):
    """gRPC servicer to start/stop profiling and dump Python thread stacks.

    Profiling is driven through an ``events.EventWithAutostop`` wrapping
    ``EnableProfiling``; the Start/State/Stop RPCs each answer with a
    ``ProfilingResponse`` built from the resulting state.
    """

    # Historical default path handed to EnableProfiling and reported to clients.
    DEFAULT_PSTATS_PATH = Path('/tmp/loadtesting/yappi_output.pstats')

    def __init__(self, pstats_path=None):
        """Create the servicer.

        :param pstats_path: path (str or Path) passed to ``EnableProfiling``
            (presumably where it writes the pstats output) and echoed back in
            every response as ``pstats_file``. ``None`` keeps
            ``DEFAULT_PSTATS_PATH``.
        """
        self._pstats_path = (
            self.DEFAULT_PSTATS_PATH if pstats_path is None else Path(pstats_path)
        )
        self._profiling = events.EventWithAutostop(
            EnableProfiling(self._pstats_path)
        )

    def _response(self, state: events.EventWithAutostop.State):
        """Convert a profiling state snapshot into a ``ProfilingResponse``.

        :param state: snapshot returned by ``start()``/``state()``/``stop()``.
        :returns: response with status, ``stop_in``, error text and pstats path.
        """
        error = state.error
        return profiling_pb2.ProfilingResponse(
            status='RUNNING' if state.is_in_progress else 'NOT RUNNING',
            # Unit is defined by EventWithAutostop (presumably seconds) — confirm.
            stop_in=state.stop_in,
            # str(None) would send the literal text 'None'; report "no error" as ''.
            error='' if error is None else str(error),
            pstats_file=str(self._pstats_path),
        )

    def Start(self, request, context):
        """Start profiling and return the resulting state.

        A falsy ``request.duration`` (e.g. an unset proto field, 0) is passed
        to ``with_stop_after`` as ``None`` — presumably "no automatic stop";
        confirm in ``events.EventWithAutostop``.

            grpcurl -plaintext  -d '{"duration": 120}'  localhost:50055 Profiling/Start
        """
        stop_after = request.duration or None
        return self._response(self._profiling.with_stop_after(stop_after).start())

    def State(self, request, response):
        """Return the current profiling state without changing it.

        ``response`` is the positional slot where gRPC passes the servicer
        context; the name is kept for compatibility and it is unused.

            grpcurl -plaintext  localhost:50055 Profiling/State
        """
        return self._response(self._profiling.state())

    def Stop(self, request, response):
        """Stop profiling and return the resulting state.

        ``response`` is the positional slot where gRPC passes the servicer
        context; the name is kept for compatibility and it is unused.

            grpcurl -plaintext  localhost:50055 Profiling/Stop
        """
        return self._response(self._profiling.stop())

    def DumpThreads(self, reqeust, response):
        """Return the current stack trace of every Python thread.

        ``reqeust`` (sic) and ``response`` are the gRPC request and servicer
        context, passed positionally; names are kept for compatibility and both
        are unused.

            grpcurl -plaintext  localhost:50055 Profiling/DumpThreads
        """
        # faulthandler writes straight to the file's OS descriptor, so it needs
        # a real file (an io.StringIO has no fileno()); TemporaryFile is
        # removed automatically when closed.
        with tempfile.TemporaryFile() as stream:
            faulthandler.dump_traceback(stream, all_threads=True)
            stream.seek(0)
            dump = stream.read()
        # `dump` is bytes (TemporaryFile opens in binary mode), passed through
        # unchanged as in the original implementation.
        return profiling_pb2.String(
            str=dump
        )
