# coding: utf-8

import json
import logging
import datetime
import aniso8601

from sandbox import sdk2
from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt
import sandbox.common.types.statistics as ctss


# Which launch a signal belongs to: SIMPLE for the run driven by plain REST
# requests, PROFILER for the run whose requests carry the profiler header.
class RunType(common.enum.Enum):
    # NOTE(review): presumably makes member values the lower-cased member
    # names (e.g. "simple") -- confirm against sandbox.common.enum.
    common.enum.Enum.lower_case()
    SIMPLE = None
    PROFILER = None


# What a signal measures:
#   STATUS          - time the tasklet task spent in one status (from its audit log)
#   REQUEST_METRICS - one timing entry of a create/enqueue REST request
#   TOTAL           - time between the first and the last audit record of the task
class MeasureType(common.enum.Enum):
    # NOTE(review): presumably makes member values the lower-cased member
    # names -- confirm against sandbox.common.enum.
    common.enum.Enum.lower_case()
    STATUS = None
    REQUEST_METRICS = None
    TOTAL = None


# Which REST request a REQUEST_METRICS signal was taken from: CREATION for the
# task-create call, ENQUEUE for the batch "start" call that enqueues the task.
class RequestType(common.enum.Enum):
    # NOTE(review): presumably makes member values the lower-cased member
    # names -- confirm against sandbox.common.enum.
    common.enum.Enum.lower_case()
    CREATION = None
    ENQUEUE = None


# Source of a profiled request metric: SERVICEAPI for the "profile" field of the
# profiled response body, LEGACY for its "legacy_profile" field.
class ProfilerType(common.enum.Enum):
    # NOTE(review): presumably makes member values the lower-cased member
    # names -- confirm against sandbox.common.enum.
    common.enum.Enum.lower_case()
    LEGACY = None
    SERVICEAPI = None


class EmptyTaskletLauncher(sdk2.Task):
    """ Launch an empty tasklet, wait for it, and collect statistics about its stages """

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024
        disk_space = 2

        class Caches(sdk2.Requirements.Caches):
            pass  # no shared caches

    class Parameters(sdk2.Parameters):
        description = "Launch empty tasklet"
        kill_timeout = 180

    @staticmethod
    def _signal(**fields):
        """
        Build one EMPTY_TASKLET_METRICS signal.

        :param fields: signal-specific fields (run_type, measure_type, name, value, ...)
        :return: `fields` extended with `date` and `timestamp`, both set to the current UTC time
        """
        now = common.api.DateTime().encode(datetime.datetime.utcnow())
        fields.update(date=now, timestamp=now)
        return fields

    @staticmethod
    def _to_ms(delta):
        """ Convert a `datetime.timedelta` to whole milliseconds, truncated toward zero """
        # Scale before truncating: int(total_seconds()) * 1000 would drop all sub-second precision.
        return int(delta.total_seconds() * 1000)

    def _tasklet_task_params(self, tasklet_binary):
        """
        Build keyword arguments for creating a TASKLET_EMPTY_TASKLET task through the REST API.

        :param tasklet_binary: SandboxTasksBinary resource found for the tasklet, or None
        :return: dict suitable for `client.task.create(**params)`
        :raises common.errors.TaskFailure: if `tasklet_binary` is None
        """
        if tasklet_binary is None:
            # NOTE(review): assumes `.first()` returns None when nothing matches -- confirm.
            # Fail with a clear message instead of an AttributeError on `None.id`.
            raise common.errors.TaskFailure(
                "No stable SandboxTasksBinary resource found for tasklet 'empty_tasklet'"
            )
        return dict(
            type="TASKLET_EMPTY_TASKLET", owner=self.owner, description="Empty tasklet description",
            requirements={"tasks_resource": tasklet_binary.id}, custom_fields=[{"name": "num", "value": 5}],
            children=True,
        )

    def simple_tasklet_request_statistics(self, response_headers, req_type, statistics):
        """
        Append REQUEST_METRICS signals for one plain (non-profiled) REST request.

        :param response_headers: headers captured from the response. REQ_METRICS must hold
            `name=duration` pairs joined by "|", and REQ_DURATION must hold the total request
            duration. Durations are presumably in seconds; they are stored multiplied by 1000.
        :param req_type: `RequestType` of the request
        :param statistics: list the signals are appended to (mutated in place)
        """
        # Build the list explicitly so the request_duration entry is one item, not iterated
        # character by character; skip empty items (e.g. an empty REQ_METRICS header).
        metrics = [item for item in response_headers[ctm.HTTPHeader.REQ_METRICS].split("|") if item]
        metrics.append("request_duration={}".format(response_headers[ctm.HTTPHeader.REQ_DURATION]))
        for metric in metrics:
            name, duration = metric.split("=", 1)
            statistics.append(self._signal(
                run_type=RunType.SIMPLE,
                measure_type=MeasureType.REQUEST_METRICS,
                request_type=req_type,
                name=name,
                value=int(float(duration) * 1000),
            ))

    def profiler_tasklet_request_statistics(self, profile, profiler_type, req_type, statistics, min_duration_ms=10):
        """
        Append REQUEST_METRICS signals parsed from a textual profile of one REST request.

        NOTE(review): assumes pstats-like text, which is not verified here:
            - a 5-line preamble;
            - rows of `ncalls tottime percall cumtime percall filename:lineno(function)`;
            - rows sorted by descending cumulative time.
        Confirm against the server.

        :param profile: profile text
        :param profiler_type: `ProfilerType` the profile came from
        :param req_type: `RequestType` of the request
        :param statistics: list the signals are appended to (mutated in place)
        :param min_duration_ms: parsing stops at the first row whose duration is <= this value
        """
        for line in profile.split("\n")[5:]:
            # str.split() drops empty fields (and stray "\r"); also yields a real list on Python 3.
            fields = line.split()
            if len(fields) < 6:
                # Blank or truncated line (e.g. after the table) -- not a profile row.
                continue
            duration = int(float(fields[3]) * 1000)
            if duration <= min_duration_ms:
                break
            statistics.append(self._signal(
                run_type=RunType.PROFILER,
                measure_type=MeasureType.REQUEST_METRICS,
                request_type=req_type,
                profiler_type=profiler_type,
                name=" ".join(fields[5:]),
                value=duration,
            ))

    def _profiler_error_statistics(self, ex_json, req_type, statistics):
        """ Append signals from both the "profile" and "legacy_profile" fields of a profiled response body """
        self.profiler_tasklet_request_statistics(
            ex_json["profile"], ProfilerType.SERVICEAPI, req_type, statistics
        )
        self.profiler_tasklet_request_statistics(
            ex_json["legacy_profile"], ProfilerType.LEGACY, req_type, statistics
        )

    def status_statistics(self, task_id, run_type):
        """
        Send per-status duration signals and one TOTAL signal for a finished tasklet task.

        Each audit record's status is charged with the time until the next audit record.
        TOTAL is the time between the first and the last record. All values are in milliseconds.

        :param task_id: ID of the tasklet task to inspect
        :param run_type: `RunType` the task was launched with
        """
        audit_list = self.server.task[task_id].audit.read()
        if not audit_list:
            logging.warning("Task %s of %s run has no audit records, no status signals sent", task_id, run_type)
            return

        # Parse every timestamp once; consecutive pairs give the duration of each status.
        times = [aniso8601.parse_datetime(record["time"]) for record in audit_list]
        statistics = [
            self._signal(
                run_type=run_type,
                measure_type=MeasureType.STATUS,
                name=record["status"],
                value=self._to_ms(end - start),
            )
            for record, start, end in zip(audit_list, times, times[1:])
        ]
        statistics.append(self._signal(
            run_type=run_type,
            measure_type=MeasureType.TOTAL,
            value=self._to_ms(times[-1] - times[0]),
        ))
        logging.info("Task %s of %s run finished. Status signals: %s", task_id, run_type, statistics)
        self.server.statistics[ctss.SignalType.EMPTY_TASKLET_METRICS](statistics)

    def on_execute(self):
        """
        Launch the empty tasklet twice: once with plain requests and once with the profiler header.

        For each run, send request-timing signals after enqueuing the task and status-duration
        signals after it finishes. Each `run_*` stage ends by raising `sdk2.WaitTask`. The
        memoize stages keep each step from repeating when this method is entered again.
        """
        tasklet_binary = sdk2.service_resources.SandboxTasksBinary.find(
            attrs={"tasklet_name": "empty_tasklet", "released": "stable"}
        ).first()

        with self.memoize_stage.run_tasklet_simple:
            statistics = []
            task_params = self._tasklet_task_params(tasklet_binary)

            response_headers = common.rest.Client.HEADERS()
            client = self.server >> response_headers
            task_id = client.task.create(**task_params)["id"]
            self.simple_tasklet_request_statistics(response_headers, RequestType.CREATION, statistics)

            response_headers = common.rest.Client.HEADERS()
            client = self.server >> response_headers
            client.batch.tasks.start.update([task_id])
            self.simple_tasklet_request_statistics(response_headers, RequestType.ENQUEUE, statistics)

            logging.info("Enqueued simple tasklet. Signals: %s", statistics)
            self.server.statistics[ctss.SignalType.EMPTY_TASKLET_METRICS](statistics)

            self.Context.tasklet_task_id_simple = task_id
            raise sdk2.WaitTask([task_id], ctt.Status.Group.FINISH + ctt.Status.Group.BREAK)

        with self.memoize_stage.check_tasklet_task_simple:
            self.status_statistics(self.Context.tasklet_task_id_simple, RunType.SIMPLE)

        with self.memoize_stage.run_tasklet_profiler:
            statistics = []
            task_params = self._tasklet_task_params(tasklet_binary)
            # The code expects profiled requests to fail with HTTPError. Its JSON body carries the
            # profiles ("profile", "legacy_profile") and, for creation, the wrapped result.
            client = self.server.copy() << common.rest.Client.HEADERS({ctm.HTTPHeader.PROFILER: "2"})
            try:
                response = client.task.create(**task_params)
            except common.rest.Client.HTTPError as ex:
                ex_json = ex.response.json()
                self._profiler_error_statistics(ex_json, RequestType.CREATION, statistics)
                # Presumably "result" holds the JSON-encoded body of the regular create response.
                task_id = json.loads(ex_json["result"][0])["id"]
            else:
                # No profiled error: use the regular response instead of enqueuing `None`.
                task_id = response["id"]

            try:
                client.batch.tasks.start.update([task_id])
            except common.rest.Client.HTTPError as ex:
                self._profiler_error_statistics(ex.response.json(), RequestType.ENQUEUE, statistics)

            logging.info("Enqueued profiler tasklet. Signals: %s", statistics)
            self.server.statistics[ctss.SignalType.EMPTY_TASKLET_METRICS](statistics)
            self.Context.tasklet_task_id_profiler = task_id
            raise sdk2.WaitTask([task_id], ctt.Status.Group.FINISH + ctt.Status.Group.BREAK)

        with self.memoize_stage.check_tasklet_task_profiler:
            self.status_statistics(self.Context.tasklet_task_id_profiler, RunType.PROFILER)
