# -*- coding: utf-8 -*-
import logging
import os

import subprocess as sp

import sandbox.sandboxsdk.environments as sdk_environments

from sandbox import sdk2

import sandbox.common.types.resource as ctr

from sandbox.projects.ads.tsar.lib import get_resource_path

from sandbox.projects.adv_machine.common import ConfigFormat, upload_file_to_yt, get_yt_config
from sandbox.projects.adv_machine.common.parameters import YTParameters

logger = logging.getLogger(__name__)


class UploadTsarPytorchToYt(sdk2.Task):
    """Build a PyTorch dump with the dump-maker binary and upload it to YT.

    The dump maker (``dump_maker`` resource) converts the unprepared model
    directory (``unprepared_torch_dump`` resource) into a dump. The dump is
    uploaded to ``yt_path`` (default:
    //home/advquality/tsar/tsar_models/pytorch/dump). The id of the source
    resource is then written into the ``checkpoint_attribute_name``
    attribute of that node.
    """
    class Requirements(sdk2.Task.Requirements):
        cores = 1

        environments = [
            sdk_environments.PipEnvironment('yandex-yt'),
            sdk_environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        ]

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        yt_params = YTParameters
        # Attribute (joined onto yt_path) that receives the source resource id.
        checkpoint_attribute_name = sdk2.parameters.String("Checkpoint (resource_id) attribute name",
                                                           default_value="@sandbox_resource_id")
        # Passed to the dump maker as `-i`.
        model_id = sdk2.parameters.Integer("Model Id of dump in sandbox (defines different types of models)",
                                           default_value=1)
        dump_maker = sdk2.parameters.Resource(
            'Dump Maker',
            state=(ctr.State.READY, ),
            required=True,
        )
        yt_path = sdk2.parameters.String("Dst Yt Path", default_value="//home/advquality/tsar/tsar_models/pytorch/dump")
        unprepared_torch_dump = sdk2.parameters.Resource(
            'Torch Dump Unprepared',
            state=(ctr.State.READY, ),
            required=True
        )

    def _upload_pytorch_dump_on_yt(self, yt, yt_client, yt_path, unprepared_dump_resource, attribute_name, dump_maker_bin, model_id):
        """Build the dump locally, upload it to ``yt_path`` and tag it.

        :param yt: the ``yt.wrapper`` module (used for ``ypath_join``).
        :param yt_client: client used for the upload and the attribute write.
        :param yt_path: destination YT path of the dump.
        :param unprepared_dump_resource: resource holding the model directory.
        :param attribute_name: attribute joined onto ``yt_path`` that receives
            the id of ``unprepared_dump_resource``.
        :param dump_maker_bin: local path to the dump-maker executable.
        :param model_id: value passed to the dump maker via ``-i``.
        :raises ValueError: if the unprepared dump is not a directory.
        :raises subprocess.CalledProcessError: if the dump maker fails.
        """
        unprepared_dump_path = get_resource_path(unprepared_dump_resource)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir(unprepared_dump_path):
            raise ValueError(
                'Pytorch dump expected to be a directory: {}'.format(unprepared_dump_path)
            )

        output_path = './dump'
        # An argv list (no shell) keeps paths with spaces or shell metacharacters
        # intact. The './' prefix makes a relative binary path resolve against the
        # working directory rather than $PATH; os.path.join keeps absolute paths as is.
        cmd = [
            os.path.join('./', str(dump_maker_bin)),
            'make-pytorch-dump',
            '--model-dir', str(unprepared_dump_path),
            '-o', output_path,
            '-i', str(model_id),
        ]
        logger.info('Running dump maker: %s', cmd)
        sp.check_call(cmd)

        upload_file_to_yt(yt_client, yt_path, output_path)
        yt_client.set(yt.ypath_join(yt_path, attribute_name), str(unprepared_dump_resource.id))

    def _upload_dumps_on_yt(self):
        """Create a YT client from the task parameters and run the upload."""
        adv_yt_config = get_yt_config(self.Parameters.yt_params, self.author, format=ConfigFormat.CLIENT)

        # Imported here rather than at module level; presumably because the
        # yandex-yt package comes from the PipEnvironment in Requirements.
        import yt.wrapper as yt
        yt_client = yt.YtClient(config=adv_yt_config)

        self._upload_pytorch_dump_on_yt(
            yt=yt,
            yt_client=yt_client,
            yt_path=self.Parameters.yt_path,
            attribute_name=self.Parameters.checkpoint_attribute_name,
            dump_maker_bin=get_resource_path(self.Parameters.dump_maker),
            unprepared_dump_resource=self.Parameters.unprepared_torch_dump,
            model_id=self.Parameters.model_id
        )

    def on_execute(self):
        """Task entry point."""
        self._upload_dumps_on_yt()
