# -*- coding: utf-8 -*-
import logging
import os
import shutil

from sandbox import sdk2

from sandbox.projects.ads.tsar.lib import make_difacto_dump

from .lib import resources  # noqa

# Module-level logger used by the task below for progress messages.
logger = logging.getLogger(__name__)


class TsarDumpsCollector(sdk2.Task):
    """Collects dumps from several task_ids.

    For every task id listed in one of the ``task_ids_*`` parameters, the task
    calls ``make_difacto_dump`` with the compression type / model format that
    the parameter selects and copies the returned file into a single
    ``TsarDumpsList`` resource directory named ``dumps_dir_name``.

    File naming inside the resource directory:
      * float + sthash dumps are stored under the bare task id;
      * every other combination is stored as ``<task_id>_<format>_<compression>``.
    """

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        # Empty override of the base Caches requirement.
        # NOTE(review): presumably declares that the task uses no caches -- confirm against sdk2 docs.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        task_ids_sthash_float = sdk2.parameters.List("List of task ids with dmlc dumps should be converted to float sthash")
        task_ids_vinyl_float = sdk2.parameters.List("List of task ids with dmlc dumps should be converted to float vinyl")
        task_ids_sthash_uint16 = sdk2.parameters.List("List of task ids with dmlc dumps should be converted to uint16 sthash")
        task_ids_vinyl_uint16 = sdk2.parameters.List("List of task ids with dmlc dumps should be converted to uint16 vinyl")
        dumps_dir_name = sdk2.parameters.String("Name of directory with dumps", default_value="dmlc_dumps")

    def _dump_specs(self):
        """Return ``(task_ids, compression_type, model_format)`` triples.

        This is the single source of truth for which parameter maps to which
        conversion. The set of requested task ids and the conversion loop both
        read from it, so they cannot drift apart.
        """
        return [
            (self.Parameters.task_ids_sthash_float, 'float', 'sthash'),
            (self.Parameters.task_ids_vinyl_float, 'float', 'vinyl'),
            (self.Parameters.task_ids_sthash_uint16, 'uint16', 'sthash'),
            (self.Parameters.task_ids_vinyl_uint16, 'uint16', 'vinyl'),
        ]

    @staticmethod
    def _dump_file_name(task_id, compression_type, model_format):
        """Return the file name a converted dump gets inside the resource directory.

        float/sthash dumps keep the bare task id. All other combinations get a
        ``_<model_format>_<compression_type>`` suffix, so a task id listed in
        several parameters yields distinct files.
        """
        if compression_type == 'float' and model_format == 'sthash':
            return task_id
        return task_id + '_' + model_format + '_' + compression_type

    def on_execute(self):
        """Convert every requested dump and publish them as one TsarDumpsList resource.

        Returns early, without creating a resource, when no task ids were given.
        """
        specs = self._dump_specs()

        all_task_ids = set()
        for task_ids, _, _ in specs:
            all_task_ids.update(task_ids)

        if not all_task_ids:
            logger.info("No tasks ids provided - nothing to do")
            return

        dumps_dir = self.Parameters.dumps_dir_name
        # Remove any pre-existing directory with this name so the resource starts empty.
        if os.path.isdir(dumps_dir):
            shutil.rmtree(dumps_dir)

        dumps_list_resource = resources.TsarDumpsList(self, "{} dmlc dumps".format(len(all_task_ids)), dumps_dir, ttl=30)
        resource_data = sdk2.ResourceData(dumps_list_resource)
        resource_data.path.mkdir(0o755, parents=True, exist_ok=True)
        resource_dir = str(resource_data.path)

        for task_ids, compression_type, model_format in specs:
            # sorted(set(...)) deduplicates like the plain set did, but gives a
            # deterministic processing and log order across runs.
            for task_id in sorted(set(task_ids)):
                last_tsar_dump_path = make_difacto_dump(task_id, compression_type, model_format)
                dst_path = os.path.join(resource_dir, self._dump_file_name(task_id, compression_type, model_format))
                logger.info("Copying dump '%s' from %s to %s", task_id, last_tsar_dump_path, dst_path)
                shutil.copy2(last_tsar_dump_path, dst_path)

        resource_data.ready()
