import logging

from sandbox import sdk2
from sandbox import common

from sandbox.projects.common.nanny import nanny

from sandbox.projects.ads.torch_model_processor import TorchModelProcessor

import sandbox.common.types.task as ctt


class TorchModelTaskScheduler(sdk2.Task):
    """Run a ``TorchModelProcessor`` subtask and release its output.

    Stage 1 enqueues a ``TorchModelProcessor`` child task with this task's
    parameters and waits for it to reach a FINISH or BREAK status.
    Stage 2 looks up READY resources produced by that child and, if enough
    were found, releases the child task as ``stable`` through the Sandbox
    REST client.
    """

    class Requirements(sdk2.Requirements):
        cores = 20
        disk_space = 32 * 1024
        ram = 64 * 1024

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 3600

        yt_token = sdk2.parameters.YavSecret("YtToken", description="Yt token name", required=True)
        yt_proxy = sdk2.parameters.String("YtProxy", description="Yt proxy", required=True, default="hahn")
        sandbox_token = sdk2.parameters.YavSecret("SandboxToken", description="Sandbox token name", required=True)

        model_yt_dir = sdk2.parameters.String("ModelYtDir", description="Yt directory with model", required=True)
        artifact_name = sdk2.parameters.String("ArtifactName", description="Artifact name", required=True, default='tsar_processed_model')
        model_id = sdk2.parameters.Integer("ModelID", description="Model ID", required=True)
        # Forwarded to the child task as-is; see the NOTE in on_execute about
        # the release type used by this scheduler.
        release = sdk2.parameters.Bool("Release", description="Release prod or testing", required=True, default=False)
        ttl = sdk2.parameters.Integer("TTL", description="Sandbox resource ttl", required=True, default=30)
        conf = sdk2.parameters.JSON("ModelDescriptorsConfig", description="Model descriptors config", required=True)

        # The selected groups are joined with ',' into the 'groups' release label
        # by get_release_labels().
        with sdk2.parameters.CheckGroup('Nanny groups to auto release', default=[]) as nanny_labels:
            groups = [
                'default', 'rsya', 'search-offline', 'search-online', 'search-big-base', 'skipper', 'robot',
                'turbo', 'geo', 'geo_robot', 'edadeal', 'edadeal_robot', 'market', 'offer_base', 'offer_base_robot', 'exp', 'exp_robot'
            ]
            nanny_labels.choices = [(s, s) for s in groups]

            release_labels = nanny.LabelsParameter2(
                'Release Labels',
                description='Labels which would be attached to nanny release',
            )

    def get_release_labels(self):
        """Build the label dict passed to the child task as ``release_labels``.

        Starts from a copy of the user-supplied ``release_labels`` parameter and
        adds two keys:

        * ``groups``: the selected Nanny groups, comma-joined;
        * ``model_id``: the model id as a string.

        User-supplied values for those two keys are overwritten.

        Returns:
            A new dict. The parameter value itself is left untouched.
        """
        # Copy rather than mutate the parameter value in place, and tolerate
        # the parameter being unset (None) instead of failing with TypeError.
        labels = dict(self.Parameters.release_labels or {})
        labels['groups'] = ','.join(self.Parameters.nanny_labels or [])
        labels['model_id'] = str(self.Parameters.model_id)
        return labels

    def on_release(self):
        """Do nothing on release of this scheduler task.

        NOTE(review): this overrides ``sdk2.Task.on_release`` with a no-op;
        confirm that suppressing the base-class release behaviour is intended.
        """
        pass

    def on_execute(self):
        """Enqueue the model-processing child task, wait for it, then release it."""
        # Stage 1 (runs once thanks to memoize_stage): spawn the child and suspend
        # this task until the child reaches a FINISH or BREAK status.
        with self.memoize_stage.process_torch_model:
            torch_model_task = TorchModelProcessor(
                self,
                kill_timeout=self.Parameters.kill_timeout,
                yt_token=self.Parameters.yt_token,
                sandbox_token=self.Parameters.sandbox_token,
                yt_proxy=self.Parameters.yt_proxy,
                model_yt_dir=self.Parameters.model_yt_dir,
                artifact_name=self.Parameters.artifact_name,
                model_id=self.Parameters.model_id,
                release=self.Parameters.release,
                ttl=self.Parameters.ttl,
                conf=self.Parameters.conf,
                release_labels=self.get_release_labels()
            ).enqueue()
            # Persist the child id in Context so it survives the WaitTask restart.
            self.Context.torch_model_task_id = torch_model_task.id
            raise sdk2.WaitTask(torch_model_task.id, list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK), True)

        # Stage 2: executed after the child task has stopped.
        child_task_id = self.Context.torch_model_task_id

        resources = list(
            sdk2.Resource.find(state="READY", task_id=child_task_id)
            .order(-sdk2.Resource.id)
            .limit(10)
        )

        logging.info('Found %s resources from task %s', len(resources), child_task_id)

        # NOTE(review): the release only happens when MORE than one READY resource
        # exists. Presumably the child always produces at least one non-model
        # resource (e.g. logs), so this acts as a "model was produced" check.
        # Confirm; otherwise this should likely be `if resources:`.
        if len(resources) > 1:
            # NOTE(review): the release type is always 'stable', independent of
            # Parameters.release ("Release prod or testing") — confirm intended.
            common.rest.Client().release({
                'task_id': child_task_id,
                'subject': 'New adv machine torch tsar model. Task id {}'.format(child_task_id),
                'type': 'stable',
            })
