import logging

from sandbox import sdk2
from sandbox import common

from sandbox.projects.common.nanny import nanny

from sandbox.projects.adv_machine.common.parameters import YTParameters

from sandbox.projects.ads.adv_machine_tsar_dumps_collector import AdvMachineTsarDumpsCollector

import sandbox.common.types.task as ctt


class AdvMachineTorchDumpListScheduler(sdk2.Task):
    """Run an ``AdvMachineTsarDumpsCollector`` subtask and release it.

    The task enqueues a single ``AdvMachineTsarDumpsCollector`` child with the
    parameters forwarded from this task and waits until the child reaches a
    FINISH or BREAK status. It then looks up READY resources produced by the
    child. If more than one is found, it issues a ``stable`` release of the
    child through the Sandbox REST client.
    """

    class Requirements(sdk2.Requirements):
        cores = 20
        disk_space = 32 * 1024
        ram = 64 * 1024

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 3600

        yt_params = YTParameters

        sandbox_models_info = sdk2.parameters.Dict("Dict of pairs <model_id:sandbox_scheduler_id>")
        banner_dumps_dir_name = sdk2.parameters.String("Name of directory with banner dumps", default_value="torch_banner_models")
        user_dumps_dir_name = sdk2.parameters.String("Name of directory with user dumps", default_value="torch_user_models")
        user_model_resource_name = sdk2.parameters.String("User model resource name", default_value="TORCH_TSAR_USER_MODEL")
        banner_model_resource_name = sdk2.parameters.String("Banner model resource name", default_value="TORCH_TSAR_BANNER_MODEL")
        tensor_model_resource_name = sdk2.parameters.String("Tensor model resource name", default_value="TORCH_TSAR_TENSOR_MODEL")
        torch_processor_task_name = sdk2.parameters.String("Torch model processor task name", default_value="TORCH_MODEL_PROCESSOR")
        yt_models_version_table = sdk2.parameters.String("YT table with last produced and released versions of models",
                                                         default_value="//home/advquality/adv_machine/torch_models_info/prod_released_tasks_info")
        release = sdk2.parameters.Bool("Release", description="Release prod or testing", required=True, default=False)
        ttl = sdk2.parameters.Integer("TTL", description="Sandbox resource ttl", required=True, default=30)

        with sdk2.parameters.CheckGroup('Nanny groups to auto release', default=[]) as nanny_labels:
            groups = [
                'default', 'rsya', 'search-offline', 'search-online', 'search-big-base', 'skipper', 'robot',
                'turbo', 'geo', 'geo_robot', 'edadeal', 'edadeal_robot', 'market', 'offer_base', 'offer_base_robot', 'exp', 'exp_robot'
            ]
            nanny_labels.choices = [(s, s) for s in groups]

            release_labels = nanny.LabelsParameter2(
                'Release Labels',
                description='Labels which would be attached to nanny release',
            )

    def get_release_labels(self):
        """Build the release labels forwarded to the child task.

        Returns:
            A new dict that copies the ``release_labels`` parameter and adds a
            ``groups`` key. That key holds the selected ``nanny_labels`` joined
            with commas.

        Making a copy keeps the task's own parameter value from being mutated.
        An unset ``release_labels`` or ``nanny_labels`` is treated as empty.
        """
        labels = dict(self.Parameters.release_labels or {})
        labels['groups'] = ','.join(self.Parameters.nanny_labels or [])
        return labels

    def on_release(self):
        # Deliberately a no-op override of sdk2.Task.on_release.
        # NOTE(review): the release this task cares about is issued for the
        # child task in on_execute; confirm that nothing is expected to happen
        # when this scheduler task itself is released.
        pass

    def on_execute(self):
        """Enqueue the dumps collector subtask, wait for it, then release it.

        Stage 1 is memoized and runs once. It creates and enqueues
        ``AdvMachineTsarDumpsCollector``, stores the child id in the context,
        and suspends this task until the child reaches a FINISH or BREAK status.

        Stage 2 runs after wake-up. It fetches up to 10 READY resources of the
        child (newest first) and releases the child as ``stable`` when more
        than one resource is found.
        """
        with self.memoize_stage.process_torch_models:
            torch_models_task = AdvMachineTsarDumpsCollector(
                self,
                kill_timeout=self.Parameters.kill_timeout,
                yt_params__proxy=self.Parameters.yt_params.proxy,
                yt_params__prefix=self.Parameters.yt_params.prefix,
                yt_params__token_vault=self.Parameters.yt_params.token_vault,
                yt_params__pool=self.Parameters.yt_params.pool,
                sandbox_models_info=self.Parameters.sandbox_models_info,
                banner_dumps_dir_name=self.Parameters.banner_dumps_dir_name,
                user_dumps_dir_name=self.Parameters.user_dumps_dir_name,
                user_model_resource_name=self.Parameters.user_model_resource_name,
                banner_model_resource_name=self.Parameters.banner_model_resource_name,
                tensor_model_resource_name=self.Parameters.tensor_model_resource_name,
                torch_processor_task_name=self.Parameters.torch_processor_task_name,
                yt_models_version_table=self.Parameters.yt_models_version_table,
                release=self.Parameters.release,
                ttl=self.Parameters.ttl,
                release_labels=self.get_release_labels()
            ).enqueue()
            # Persist the child id so it survives the WaitTask suspension below.
            self.Context.torch_models_task_id = torch_models_task.id
            raise sdk2.WaitTask(torch_models_task.id, list(ctt.Status.Group.FINISH + ctt.Status.Group.BREAK), True)

        # Read the child id once; every later use must refer to the same
        # context key that was written above.
        child_task_id = self.Context.torch_models_task_id

        resources = list(
            sdk2.Resource.find(state="READY", task_id=child_task_id)
            .order(-sdk2.Resource.id)
            .limit(10)
        )

        logging.info('Found {} resources from task {}'.format(len(resources), child_task_id))

        # NOTE(review): the child's final status is not checked. If it broke,
        # this branch is simply skipped and this task still finishes normally.
        # The "> 1" threshold presumably excludes a resource every task
        # produces (e.g. logs); confirm.
        if len(resources) > 1:
            # NOTE(review): the release type is always 'stable', regardless of
            # the `release` parameter forwarded to the child; confirm intent.
            common.rest.Client().release({
                'task_id': child_task_id,
                'subject': 'ADV_MACHINE_TSAR_DUMP_LIST_COLLECTOR. Task id {}'.format(child_task_id),
                'type': 'stable',
            })
