# -*- coding: utf-8 -*-
import logging
import os
import shutil

from collections import defaultdict

from sandbox import sdk2

import sandbox.sandboxsdk.environments as sdk_environments

from sandbox.projects.adv_machine.common import ConfigFormat, get_yt_config
from sandbox.projects.adv_machine.common.parameters import YTParameters

from sandbox.projects.common.nanny.client import NannyClient
from sandbox.projects.common.nanny import nanny
from sandbox.projects.common.nanny import const

from sandbox.projects.adv_machine.common import resources  # noqa

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


# Task overview:
#   1. For every model id in ``sandbox_models_info`` find the newest released
#      ("stable") processor resource and the Sandbox task that produced it.
#   2. Compare those task ids with the ones stored in the YT versions table;
#      if nothing changed, finish without producing resources.
#   3. Otherwise build two directory resources (banner dumps and user dumps),
#      each laid out as ``models/<model_id>/...``, store the new task ids in
#      YT and mark both resources as ready.
class AdvMachineTsarDumpsCollector(nanny.ReleaseToNannyTask2, sdk2.Task):
    """Collects torch dumps from several model_ids for adv_machine"""
    class Requirements(sdk2.Task.Requirements):
        cores = 1

        # The YT client is installed at run time, hence the late
        # ``import yt.wrapper`` inside ``on_execute``.
        environments = [
            sdk_environments.PipEnvironment('yandex-yt'),
            sdk_environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        ]

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        yt_params = YTParameters
        # Only the keys (model ids) are read by this task.
        sandbox_models_info = sdk2.parameters.Dict("Dict of pairs <model_id:sandbox_scheduler_id>")
        banner_dumps_dir_name = sdk2.parameters.String("Name of directory with banner dumps", default_value="torch_banner_models")
        user_dumps_dir_name = sdk2.parameters.String("Name of directory with user dumps", default_value="torch_user_models")
        user_model_resource_name = sdk2.parameters.String("User model resource name", default_value="TORCH_TSAR_USER_MODEL")
        banner_model_resource_name = sdk2.parameters.String("Banner model resource name", default_value="TORCH_TSAR_BANNER_MODEL")
        tensor_model_resource_name = sdk2.parameters.String("Tensor model resource name", default_value="TORCH_TSAR_TENSOR_MODEL")
        torch_processor_task_name = sdk2.parameters.String("Torch model processor task name", default_value="TORCH_MODEL_PROCESSOR")
        yt_models_version_table = sdk2.parameters.String("YT table with last produced and released versions of models",
                                                         default_value="//home/advquality/adv_machine/torch_models_info/prod_released_tasks_info")
        # NOTE(review): ``release`` is not read anywhere in this class.
        release = sdk2.parameters.Bool("Release", description="Release prod or testing", required=True, default=False)
        # TTL applied to both dump-list resources created in ``on_execute``.
        ttl = sdk2.parameters.Integer("TTL", description="Sandbox resource ttl", required=True, default=30)
        release_labels = nanny.LabelsParameter2(
            'Release Labels',
            description='Labels which would be attached to nanny release',
        )

    def _find_latest_stable_resource(self, model_id, **query):
        """Return the newest READY resource released as stable for ``model_id``.

        :param model_id: value of the ``model_id`` resource attribute to match.
        :param query: extra ``sdk2.Resource.find`` filters (``type``,
            ``task_id``, ``task_type``, ...).
        :return: the matching resource with the highest id, or ``None``.
        """
        query.update(
            state='READY',
            attrs={
                'released': 'stable',
                'model_id': model_id,
            },
        )
        return sdk2.Resource.find(**query).order(-sdk2.Resource.id).first()

    def _get_task_resource_by_params(self, resource_type_str, task_id, model_id):
        """Return the stable resource of the given type produced by ``task_id``.

        :param resource_type_str: resource type name (converted with ``str``).
        :param task_id: id of the Sandbox task that produced the resource.
        :param model_id: value of the ``model_id`` resource attribute.
        :return: the newest matching resource, or ``None`` if there is none.
        """
        return self._find_latest_stable_resource(
            model_id,
            type=str(resource_type_str),
            task_id=task_id,
        )

    def _get_last_tensor_resource(self, model_id):
        """Return the newest stable processor resource for ``model_id``.

        The resource is looked up among tasks of type
        ``torch_processor_task_name``; its ``task_id`` is what identifies the
        latest released processor run for the model.

        NOTE(review): despite the method name, the lookup uses
        ``banner_model_resource_name``. This may be deliberate (the tensor
        resource is treated as optional in ``on_execute``) - confirm.

        :return: the newest matching resource, or ``None`` if there is none.
        """
        return self._find_latest_stable_resource(
            model_id,
            type=str(self.Parameters.banner_model_resource_name),
            task_type=str(self.Parameters.torch_processor_task_name),
        )

    def _get_task_by_id(self, task_id):
        """Return the Sandbox task with the given id, or ``None`` if not found."""
        return sdk2.Task.find(id=task_id).order(-sdk2.Task.id).first()

    def _get_last_released_tasks_info(self, yt):
        """Read ``{model_id: last_released_sb_task_id}`` from the YT table.

        :param yt: YT client exposing ``exists`` and ``read_table``.
        :return: ``defaultdict(int)`` keyed by integer model id; empty when the
            table does not exist.
        """
        released = defaultdict(int)
        table = self.Parameters.yt_models_version_table
        if not yt.exists(table):
            return released
        for row in yt.read_table(table):
            released[int(row['model_id'])] = int(row['last_released_sb_task_id'])
        return released

    def _is_new_model_task(self, models_info, model_id, task_id):
        """Tell whether ``task_id`` differs from the one recorded for the model.

        :return: ``True`` if the model is absent from ``models_info`` or its
            recorded task id is not ``task_id``.
        """
        model_id = int(model_id)
        # ``in`` is checked first so a defaultdict is not populated by the lookup.
        return model_id not in models_info or models_info[model_id] != int(task_id)

    def _write_tasks_info_2_yt(self, tasks_info, yt):
        """Write ``tasks_info`` (model id -> task id) to the YT versions table.

        The rows are passed to a single ``yt.write_table`` call on the
        configured table path.
        """
        rows = [
            {
                'model_id': int(model_id),
                'last_released_sb_task_id': int(task_id),
            }
            for model_id, task_id in tasks_info.items()
        ]
        yt.write_table(self.Parameters.yt_models_version_table, rows, format='json')

    def _get_resource_data_path(self, resource):
        """Return the local filesystem path of ``resource`` data as a string."""
        return str(sdk2.ResourceData(resource).path)

    def _write_sb_task_id_2_file(self, task_id, path):
        """Write ``task_id`` as text into the file at ``path`` (overwriting it)."""
        with open(path, 'w') as f:
            f.write(str(task_id))

    def _copy_required_dump(self, resource_type_str, task_id, model_id, dst_root, dump_dir_name):
        """Copy a mandatory model dump into ``<dst_root>/<model_id>/``.

        Creates the per-model directory, writes the producing task id into a
        ``sandbox_task_id`` file and copies the resource data into the
        ``dump_dir_name`` sub-directory.

        :return: path of the created per-model directory.
        :raises RuntimeError: if the task has no matching stable resource.
        """
        resource = self._get_task_resource_by_params(resource_type_str, task_id, int(model_id))
        if resource is None:
            raise RuntimeError(
                'No stable READY resource {} found in task {} for model {}'.format(
                    resource_type_str, task_id, model_id
                )
            )

        model_dir = os.path.join(dst_root, str(model_id))
        os.makedirs(model_dir)
        self._write_sb_task_id_2_file(task_id, os.path.join(model_dir, 'sandbox_task_id'))
        shutil.copytree(
            self._get_resource_data_path(resource),
            os.path.join(model_dir, dump_dir_name),
        )
        return model_dir

    def _copy_optional_tensor(self, task_id, model_id, dst_user_data_dir):
        """Copy the optional tensor model next to the user dump, best-effort.

        A missing tensor resource or any failure while fetching/copying it is
        logged and otherwise ignored, so the task continues without a tensor.
        """
        dst_tensor_data_path = os.path.join(dst_user_data_dir, 'tensor_model')
        try:
            tensor_resource = self._get_task_resource_by_params(
                self.Parameters.tensor_model_resource_name, task_id, int(model_id)
            )
            if tensor_resource is None:
                logger.info('There was no tensor for model {}, we continue without it'.format(model_id))
                return
            shutil.copytree(self._get_resource_data_path(tensor_resource), dst_tensor_data_path)
        except Exception:
            # Keep the step best-effort, but record the traceback and do not
            # leave a half-copied tensor directory inside the output resource.
            logger.exception('Failed to copy tensor for model %s, we continue without it', model_id)
            shutil.rmtree(dst_tensor_data_path, ignore_errors=True)

    def on_execute(self):
        """Collect dumps of all configured models into two directory resources.

        Finishes early (creating no resources) when no model ids are configured
        or when no model has a newer released task than recorded in YT.

        :raises RuntimeError: if a model has no released processor resource,
            its producing task cannot be found, or a mandatory banner/user
            resource is missing.
        """
        adv_yt_config = get_yt_config(self.Parameters.yt_params, self.author, format=ConfigFormat.CLIENT)
        # Imported here because the package comes from the task's PipEnvironment.
        import yt.wrapper as yt

        yt_client = yt.YtClient(config=adv_yt_config)

        # ``or {}`` tolerates an unset Dict parameter.
        configured_models = self.Parameters.sandbox_models_info or {}
        all_model_ids = set(int(model_id) for model_id in configured_models)

        if not all_model_ids:
            logger.info("No models info provided - nothing to do")
            return

        currently_released_models_info = self._get_last_released_tasks_info(yt=yt_client)

        updated_models = {}
        sb_models_info = {}

        for model_id in sorted(all_model_ids):
            last_resource = self._get_last_tensor_resource(model_id)
            if last_resource is None:
                raise RuntimeError(
                    'No stable READY {} resource from {} tasks found for model {}'.format(
                        self.Parameters.banner_model_resource_name,
                        self.Parameters.torch_processor_task_name,
                        model_id,
                    )
                )

            task = self._get_task_by_id(last_resource.task_id)
            if task is None:
                raise RuntimeError(
                    'Task {} that produced resource {} for model {} was not found'.format(
                        last_resource.task_id, last_resource.id, model_id
                    )
                )

            task_id = int(task.id)
            if self._is_new_model_task(currently_released_models_info, model_id, task_id):
                updated_models[model_id] = task_id
            sb_models_info[model_id] = task_id

        if not updated_models:
            logger.info('No models updated since last run. Finish.')
            return

        debug_str = ', '.join('({}: {})'.format(k, v) for k, v in sorted(updated_models.items()))
        logger.info('Models updates: %s', debug_str)

        # Start from clean output directories in the task working directory.
        for dirname in (self.Parameters.banner_dumps_dir_name, self.Parameters.user_dumps_dir_name):
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)

        all_model_ids_str = ','.join(map(str, sorted(all_model_ids)))
        resource_ttl = self.Parameters.ttl

        banner_dumps_list_resource = resources.AdvMachineBannerTsarDumpsList(
            self,
            "AdvMachine torch {} banner parts".format(all_model_ids_str),
            self.Parameters.banner_dumps_dir_name,
            ttl=resource_ttl,
        )
        user_dumps_list_resource = resources.AdvMachineUserTsarDumpsList(
            self,
            "AdvMachine torch {} user parts".format(all_model_ids_str),
            self.Parameters.user_dumps_dir_name,
            ttl=resource_ttl,
        )

        dst_resource_info = {}

        for entity, resource in (('banner', banner_dumps_list_resource), ('user', user_dumps_list_resource)):
            entity_resource_data = sdk2.ResourceData(resource)
            entity_resource_data.path.mkdir(0o755, parents=True, exist_ok=True)
            entity_models_root = os.path.join(str(entity_resource_data.path), 'models')
            if os.path.isdir(entity_models_root):
                shutil.rmtree(entity_models_root)
            os.makedirs(entity_models_root)
            dst_resource_info[entity] = {
                'root': entity_models_root,
                'resource_data': entity_resource_data,
            }

        for model_id, task_id in sb_models_info.items():
            self._copy_required_dump(
                self.Parameters.banner_model_resource_name,
                task_id,
                model_id,
                dst_resource_info['banner']['root'],
                'BannerNamespaces',
            )
            dst_user_data_dir = self._copy_required_dump(
                self.Parameters.user_model_resource_name,
                task_id,
                model_id,
                dst_resource_info['user']['root'],
                'UserNamespaces',
            )
            # The tensor model lives inside the user dump directory.
            self._copy_optional_tensor(task_id, model_id, dst_user_data_dir)

        self._write_tasks_info_2_yt(sb_models_info, yt=yt_client)

        for entity_info in dst_resource_info.values():
            entity_info['resource_data'].ready()

    def on_release(self, params):
        """Run the base Nanny release hook, then create a labelled release request.

        After ``nanny.ReleaseToNannyTask2.on_release`` the method builds the
        release payload from ``get_nanny_release_info``, attaches
        ``release_labels`` under ``meta.labels``, submits it through
        ``NannyClient.create_release2`` and records a link to the created
        request in the task info.
        """
        nanny.ReleaseToNannyTask2.on_release(self, params)
        nanny_client = NannyClient(
            api_url=const.NANNY_API_URL,
            oauth_token=sdk2.Vault.data('adv_machine_nanny_token'),
        )
        # defaultdict(dict) guarantees that data['meta'] exists before labels are set.
        data = defaultdict(dict, self.get_nanny_release_info(params))
        data['meta']['labels'] = self.Parameters.release_labels
        result = nanny_client.create_release2(dict(data))['value']
        release_id = result['id']
        release_link = const.RELEASE_REQUEST_TEMPLATE.format(
            nanny_api_url=const.NANNY_API_URL,
            release_request_id=release_id,
        )
        self.set_info(
            'Nanny release request <a href="{}">{}</a> was created.'.format(release_link, release_id),
            do_escape=False
        )
