import collections
import json
import os
from sandbox import sdk2
from sandbox.sdk2.helpers import subprocess

from sandbox.projects.common.nanny.client import NannyClient
from sandbox.projects.common.nanny import nanny
from sandbox.projects.common.nanny import const


class TorchModelProcessor(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Run the TSAR torch model processor binary and release its result to Nanny.

    ``on_execute`` fetches the newest READY ``TORCH_TSAR_PROCESSOR_BINARY``
    resource released to testing, dumps the descriptors config to a local JSON
    file and runs the binary with YT/Sandbox credentials in its environment.
    ``on_release`` additionally creates a Nanny release request carrying the
    configured labels and model id.
    """

    class Requirements(sdk2.Requirements):
        cores = 20
        disk_space = 32 * 1024
        ram = 64 * 1024

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 3600

        yt_token = sdk2.parameters.YavSecret("YtToken", description="Yt token name", required=True)
        sandbox_token = sdk2.parameters.YavSecret("SandboxToken", description="Sandbox token name", required=True)
        yt_proxy = sdk2.parameters.String("YtProxy", description="Yt proxy", required=True, default="hahn")

        model_yt_dir = sdk2.parameters.String("ModelYtDir", description="Yt directory with model", required=True)
        artifact_name = sdk2.parameters.String("ArtifactName", description="Artifact name", required=True, default='tsar_processed_model')
        model_id = sdk2.parameters.Integer("ModelID", description="Model ID", required=True)
        release = sdk2.parameters.Bool("Release", description="Release prod or testing", required=True, default=False)
        ttl = sdk2.parameters.Integer("TTL", description="Sandbox resource ttl", required=True, default=30)
        conf = sdk2.parameters.JSON("ModelDescriptorsConfig", description="Model descriptors config", required=True)

        release_labels = nanny.LabelsParameter2(
            'Release Labels',
            description='Labels which would be attached to nanny release',
        )

    def on_execute(self):
        """Run the processor binary on the model stored in ``model_yt_dir``.

        Raises:
            RuntimeError: if no suitable processor binary resource is found.
            Any error raised by ``subprocess.check_call`` if the binary fails.
        """
        binary = self._find_processor_binary()

        conf_path = 'model_process_descriptors_config.json'
        # Close (and therefore flush) the file before the binary is started so
        # it is guaranteed to read the complete config.
        with open(conf_path, 'w') as conf_file:
            json.dump(self.Parameters.conf, conf_file)

        env = os.environ.copy()
        env.update({
            'YT_TOKEN': self.Parameters.yt_token.data()['token'],
            'YT_PROXY': self.Parameters.yt_proxy,
            'SANDBOX_TOKEN': self.Parameters.sandbox_token.data()['token'],
            'SYNCHROPHAZOTRON_PATH': str(self.synchrophazotron),
        })

        cmd = self._build_command(binary, conf_path)

        with sdk2.helpers.ProcessLog(self, logger="torch_model_processor") as pl:
            # Passing an argument list without a shell keeps parameter values
            # containing spaces or shell metacharacters intact instead of
            # having them split or interpreted by /bin/sh.
            subprocess.check_call(
                cmd,
                stdout=pl.stdout,
                stderr=subprocess.STDOUT,
                env=env,
            )

    def _find_processor_binary(self):
        """Return the local path of the newest READY processor binary released to testing.

        Raises:
            RuntimeError: if no matching resource exists.
        """
        # NOTE(review): the binary is always taken from the "testing" release,
        # independently of the "Release" parameter -- confirm this is intended.
        resource = sdk2.Resource.find(
            type="TORCH_TSAR_PROCESSOR_BINARY",
            status="READY",
            attrs={"released": "testing"},
        ).order(-sdk2.Resource.id).first()
        if resource is None:
            raise RuntimeError(
                "No READY TORCH_TSAR_PROCESSOR_BINARY resource released to testing was found"
            )
        return sdk2.ResourceData(resource).path

    def _build_command(self, binary, conf_path):
        """Build the processor command line as an argument list.

        Flags are emitted in the same order as the historical shell command;
        ``--release`` is included only when the ``release`` parameter is set.
        """
        cmd = [
            str(binary),
            '--model_yt_dir', self.Parameters.model_yt_dir,
            '--artifact_name', self.Parameters.artifact_name,
        ]
        if self.Parameters.release:
            cmd.append('--release')
        cmd.extend([
            '--sandbox',
            '--ttl', str(self.Parameters.ttl),
            '--model_id', str(self.Parameters.model_id),
            '--model_descriptors_config', conf_path,
        ])
        return cmd

    def on_release(self, params):
        """Run the base Nanny release flow, then create a Nanny release request.

        The request is built from ``get_nanny_release_info(params)``. It is
        extended with the task's release labels and model id. A link to the
        created request is added to the task info.
        """
        nanny.ReleaseToNannyTask2.on_release(self, params)
        nanny_client = NannyClient(
            api_url=const.NANNY_API_URL,
            oauth_token=sdk2.Vault.data('adv_machine_nanny_token'),
        )
        # defaultdict lets 'meta' / 'model' be filled in even when the base
        # release info does not contain those sections yet.
        data = collections.defaultdict(dict, self.get_nanny_release_info(params))
        data['meta']['labels'] = self.Parameters.release_labels
        data['model']['id'] = int(self.Parameters.model_id)
        result = nanny_client.create_release2(dict(data))['value']
        release_id = result['id']
        release_link = const.RELEASE_REQUEST_TEMPLATE.format(
            nanny_api_url=const.NANNY_API_URL,
            release_request_id=release_id,
        )
        self.set_info(
            'Nanny release request <a href="{}">{}</a> was created.'.format(release_link, release_id),
            do_escape=False
        )
