# coding: U8
import logging
from sandbox import sdk2
from sandbox import common
from sandbox.common import errors
import sandbox.common.types.task as ctt
from sandbox.projects.voicetech import resource_types
from sandbox.projects.voicetech.tts_server.fastdata.BuildTtsRuFastData import BuildTtsRuFastData
from sandbox.projects.voicetech.tts_server.fastdata.TestTtsRuFastData import TestTtsRuFastData


class BuildAndTestTtsRuFastData(sdk2.Task):
    """Build TTS RU fast data in a subtask, then validate it in another subtask.

    ``on_execute`` runs these steps, each wrapped in a ``memoize_stage`` so a
    step is not repeated when the task is re-executed after ``sdk2.WaitTask``:

    1. Create the output resources ``fast_data`` and ``fast_data_bundle``.
    2. Spawn ``BuildTtsRuFastData``, which is handed those resources as
       external outputs, and wait for it to finish.
    3. If the build did not succeed, delete both resources and fail.
    4. Spawn ``TestTtsRuFastData`` on the built bundle and wait for it.
    5. If the test did not succeed, delete both resources and fail.
    """

    class Requirements(sdk2.Task.Requirements):
        disk_space = 1024  # 1 GB

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 30 * 60  # 30 min

        validator_nanny_service = sdk2.parameters.String(
            "TTS Fast Data Validator Nanny service name",
            required=True,
            default="fastdata-validator"
        )

        with sdk2.parameters.Group("Vault"):
            # Forwarded unchanged to the TestTtsRuFastData subtask.
            nanny_token_name = sdk2.parameters.String(
                "Nanny token name", required=True)

        with sdk2.parameters.Output:
            # These will be filled in by BuildTtsRuFastData subtask
            fast_data = sdk2.parameters.Resource(
                "TTS RU Fast Data",
                resource_type=resource_types.VOICETECH_TTS_RU_FASTDATA,
            )
            fast_data_bundle = sdk2.parameters.Resource(
                "Fast Data bundle",
                resource_type=resource_types.VOICETECH_TTS_RU_FASTDATA_BUNDLE,
            )

    class Context(sdk2.Task.Context):
        # Names passed as the path argument when the output resources are
        # created in on_execute.
        data_name = "ru_fastdata"
        bundle_name = data_name + ".tar.gz"

    def build_fast_data(self):
        """Enqueue a BuildTtsRuFastData subtask and return its id.

        The subtask is pointed at this task's own output resources
        (``use_external_resources=True``), so it fills them in instead of
        producing separate ones.
        """
        build_task = BuildTtsRuFastData(
            self,
            description="Build TTS RU Fast Data",
            owner=self.owner,
            use_external_resources=True,
            external_data_resource=self.Parameters.fast_data,
            external_bundle_resource=self.Parameters.fast_data_bundle
        )
        return build_task.enqueue().id

    def test_fast_data(self):
        """Enqueue a TestTtsRuFastData subtask for the built bundle; return its id.

        Reads the build subtask id from ``self.Context.build_task_id``, so it
        must only be called after the build stage has stored it.
        """
        build_task = sdk2.Task[self.Context.build_task_id]
        # NOTE(review): assumes BuildTtsRuFastData exposes the bundle it wrote
        # as its ``bundle_resource`` parameter. Since the bundle was passed in
        # as ``external_bundle_resource``, this is presumably the same resource
        # as self.Parameters.fast_data_bundle -- confirm.
        bundle = build_task.Parameters.bundle_resource
        test_task = TestTtsRuFastData(
            self,
            description="Test resource {}".format(bundle.id),
            owner=self.owner,
            fast_data_bundle=bundle,
            validator_nanny_service=self.Parameters.validator_nanny_service,
            nanny_token_name=self.Parameters.nanny_token_name
        )
        return test_task.enqueue().id

    def _delete_broken_fast_data(self):
        """Delete both output resources after a failed build or test.

        Both ids go into one batch request rather than one request per
        resource.
        """
        self.server.batch.resources["delete"].update(
            id=[
                self.Parameters.fast_data.id,
                self.Parameters.fast_data_bundle.id,
            ],
            comment="Remove broken fast data"
        )

    def on_execute(self):
        """Drive the create -> build -> check -> test -> check pipeline.

        Raises:
            errors.TaskFailure: if the build or the test subtask did not end
                in a SUCCEED-group status. Both output resources are deleted
                before the exception is raised.
        """
        with self.memoize_stage.create_resources:
            self.Parameters.fast_data = resource_types.VOICETECH_TTS_RU_FASTDATA(
                self,
                "TTS RU Fast Data",
                self.Context.data_name
            )
            self.Parameters.fast_data_bundle = resource_types.VOICETECH_TTS_RU_FASTDATA_BUNDLE(
                self,
                "TTS RU Fast Data Bundle",
                self.Context.bundle_name
            )

        with self.memoize_stage.build_fast_data:
            logging.info("Spawning build task")
            build_task_id = self.build_fast_data()
            # Stored in Context so later executions can find the subtask.
            self.Context.build_task_id = build_task_id
            raise sdk2.WaitTask(
                build_task_id,
                ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                wait_all=True
            )

        with self.memoize_stage.check_build_status:
            logging.info("Check build task status")
            build_task = sdk2.Task[self.Context.build_task_id]
            if build_task.status not in ctt.Status.Group.SUCCEED:
                self._delete_broken_fast_data()
                raise errors.TaskFailure("Fast data build failed, see task {}".format(build_task.id))

        with self.memoize_stage.test_fast_data:
            logging.info("Spawning test task")
            test_task_id = self.test_fast_data()
            self.Context.test_task_id = test_task_id
            raise sdk2.WaitTask(
                test_task_id,
                ctt.Status.Group.FINISH | ctt.Status.Group.BREAK,
                wait_all=True
            )

        with self.memoize_stage.check_test_status:
            logging.info("Check test status")
            test_task = sdk2.Task[self.Context.test_task_id]
            if test_task.status not in ctt.Status.Group.SUCCEED:
                self._delete_broken_fast_data()
                raise errors.TaskFailure("Fast data test failed, see task {}".format(test_task.id))

