# coding: utf-8
import io
import cProfile
import pstats

from sandbox import common, sdk2

# HTTP verbs offered in the task's "HTTP method" parameter.
HTTP_METHODS = (
    "GET",
    "POST",
    "PUT",
    "DELETE",
)
# Every sort key accepted by pstats.Stats.sort_stats(), in the dict's own order.
SORT_FIELDS = tuple(pstats.Stats.sort_arg_dict_default.keys())


class Profile(Exception):
    """
    Exception used to carry a text profiling report.

    Both ``repr()`` and ``str()`` return the first constructor argument as-is,
    so the report is shown without the usual ``Profile('...')`` wrapping.
    """

    def __repr__(self):
        # Guard against a bare ``Profile()``: indexing empty args would make
        # repr()/str() themselves raise IndexError. Mirror Exception.__str__ instead.
        if not self.args:
            return ""
        return self.args[0]

    __str__ = __repr__


class ProfileSandboxEndpoint(sdk2.Task):
    """
    Profile any legacy-server endpoint from on_enqueue. For non-binary runs only.

    ``on_enqueue`` issues the configured request through an in-process
    dispatched REST client while ``cProfile`` is enabled. It then always raises
    ``Profile`` whose message is the ``pstats`` report.
    """

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 2048
        disk_space = 15

        class Caches(sdk2.Requirements.Caches):
            pass  # no shared caches

    class Parameters(sdk2.Parameters):
        description = "Profile any legacy-server endpoint"
        kill_timeout = 180

        method = sdk2.parameters.String(
            "HTTP method", default="GET", choices=tuple((x, x) for x in HTTP_METHODS), required=True
        )
        path = sdk2.parameters.String("Path", default="", required=True)
        params = sdk2.parameters.Dict("Query Parameters", required=False)
        data = sdk2.parameters.JSON("Data", required=False)
        sortby = sdk2.parameters.String("Sort By", choices=tuple((x, x) for x in SORT_FIELDS), required=True)

    def on_enqueue(self):
        """
        Call the configured endpoint under cProfile and report the statistics.

        The request is made via ``self.server.<read|create|update|delete>``,
        selected by ``Parameters.method``. ``path`` and ``params`` are passed in,
        plus ``data`` for non-GET methods.

        Raises:
            Profile: always on success, carrying the ``pstats`` text report
                sorted by ``Parameters.sortby``.
            KeyError: if ``Parameters.method`` is not one of HTTP_METHODS.
            Exception: whatever the profiled request itself raises is propagated.
        """
        # Imported lazily. Presumably this server-side module is only importable
        # in the server process (hence "non-binary runs only"). TODO confirm.
        from sandbox.yasandbox.controller import dispatch

        client_method = {
            "GET": "read",
            "POST": "create",
            "PUT": "update",
            "DELETE": "delete",
        }[self.Parameters.method]

        kwargs = {
            "path": self.Parameters.path,
            "params": self.Parameters.params or {},
            "input_mode": common.rest.Client.JSON(),
            "output_mode": common.rest.Client.JSON(),
        }

        # Only non-GET requests send a request body.
        if self.Parameters.method != "GET":
            kwargs["data"] = self.Parameters.data

        profiler = cProfile.Profile()
        with common.rest.DispatchedClient as dsp:
            # Registers an in-process RestClient for this task and author. Presumably
            # self.server calls are then served without real HTTP. TODO confirm.
            dsp(dispatch.RestClient(self.id, self.author, jailed=False))
            profiler.enable()
            try:
                getattr(self.server, client_method)(**kwargs)
            finally:
                # Always unhook the profiler. A failing request must not leave
                # profiling active inside the long-lived server process.
                profiler.disable()

        report = io.StringIO()
        stats = pstats.Stats(profiler, stream=report).sort_stats(self.Parameters.sortby)
        stats.print_stats()

        # The report is delivered by raising, not by returning.
        raise Profile(report.getvalue())
