import logging
from copy import deepcopy

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types import task as task_type
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.common.yabs.server.util.general import check_tasks
from sandbox.projects.yabs.qa.resource_types import YABS_SERVER_B2B_BINARY_BASE
from sandbox.projects.yabs.qa.tasks.YabsServerDownloadBases import YabsServerDownloadBases
from sandbox.projects.yabs.qa.utils.bstr import get_bstr_info


# NOTE(review): presumably the local directory name for bases; not referenced in this chunk.
BASES_DIR = 'bases'
logger = logging.getLogger(name=__name__)


class YabsServerGetProdBases(sdk2.Task):
    """Get yabs_server binary bases from production.

    The task looks up bstr info for the requested bases in YT and reuses
    already existing READY ``YABS_SERVER_B2B_BINARY_BASE`` resources. It
    spreads the remaining bases over ``YabsServerDownloadBases`` subtasks.
    Bases from failed subtasks are rescheduled until a base fails more than
    ``fails_limit`` times. The collected ``{base_tag: resource id}`` mapping is
    published in the ``base_resources`` output parameter.
    """

    class Requirements(sdk2.Requirements):
        ram = 4 * 1024
        cores = 1
        environments = (
            PipEnvironment('yandex-yt', use_wheel=True),
        )

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        bases = sdk2.parameters.JSON('List of base names to download')
        base_ver = sdk2.parameters.String('Base version to download', default='0541333430')
        yt_path = sdk2.parameters.String('Path with bases bstr info', default='//home/yabs-transport/transport/bsfrontend_dir')
        yt_proxy = sdk2.parameters.String('YT cluster with bases info', default='locke')

        with sdk2.parameters.Group('Misc') as misc:
            desired_n_subtasks = sdk2.parameters.Integer('Count of tasks to use for download', default=12)
            # Maximum number of failed download attempts tolerated per base;
            # the next failure fails the whole task.
            fails_limit = sdk2.parameters.Integer('Fails limit', default=3)

        with sdk2.parameters.Output:
            base_resources = sdk2.parameters.JSON('Dict with bases')

    class Context(sdk2.Context):
        # IDs of the download subtasks launched by the latest scheduling round.
        download_task_ids = []
        # base_tag -> number of failed download attempts so far.
        bases_fail_counter = {}
        # base_tag -> resource id of bases that are already available.
        already_done_bases = {}

    def _get_bstr_info(self, bases):
        """Fetch bstr info for ``bases`` from YT.

        :raises TaskFailure: if YT has no info for some of the bases
        """
        from yt.wrapper import YtClient
        yt_token = sdk2.Vault.data(self.owner, 'yabscs_yt_token')
        yt_client = YtClient(proxy=self.Parameters.yt_proxy, token=yt_token)
        bstr_info = get_bstr_info(bases, self.Parameters.base_ver, yt_client, self.Parameters.yt_path)

        missing_bases = set(bases) - set(bstr_info)
        if missing_bases:
            logger.error('Did not find bases %s in %s %s', missing_bases, self.Parameters.yt_proxy, self.Parameters.yt_path)
            raise TaskFailure('Cannot find bases: {}'.format(','.join(sorted(missing_bases))))
        return bstr_info

    @staticmethod
    def _split_existing_bases(bstr_info):
        """Separate bases with an existing READY resource from those to download.

        :returns: ``(already_done_bases, bases_to_download)`` where the first is
            ``{base_tag: resource id}`` and the second is a subset of ``bstr_info``
        """
        already_done_bases = {}
        bases_to_download = {}
        for base_tag, base_bstr_info in bstr_info.items():
            resource = YABS_SERVER_B2B_BINARY_BASE.find(
                state='READY',
                attrs={'tag': base_tag, 'torrent': base_bstr_info['torrent']},
                limit=1,
            ).first()
            if resource:
                already_done_bases[base_tag] = resource.id
            else:
                bases_to_download[base_tag] = base_bstr_info
        return already_done_bases, bases_to_download

    @staticmethod
    def _split_into_chunks(bases_to_download, limit):
        """Pack bases into chunks whose total ``fsize`` stays below ``limit``.

        Uses first-fit decreasing. A base larger than ``limit`` gets a chunk of
        its own.

        :returns: list of ``{'size': int, 'items': {base_tag: download_info}}``
        """
        chunks = []
        # Largest bases first packs tighter; the tag tie-breaker makes the
        # result independent of dict iteration order.
        ordered = sorted(bases_to_download.items(), key=lambda item: (-item[1]['fsize'], item[0]))
        for base_tag, base_bstr_info in ordered:
            size = base_bstr_info['fsize']
            download_info = {k: base_bstr_info[k] for k in ('torrent', 'name', 'mtime')}
            for chunk in chunks:
                if chunk['size'] + size < limit:
                    chunk['items'][base_tag] = download_info
                    chunk['size'] += size
                    break
            else:
                chunks.append({'size': size, 'items': {base_tag: download_info}})
        return chunks

    def _launch_download_tasks(self, bases, description='Download'):
        """Enqueue ``YabsServerDownloadBases`` subtasks for ``bases``.

        :param bases: base tags to obtain
        :param description: prefix for the subtask descriptions
        :returns: ``(download_task_ids, already_done_bases)``. The first is the
            IDs of the enqueued subtasks (possibly empty). The second is
            ``{base_tag: resource id}`` for bases reused from READY resources.
        :raises TaskFailure: if YT has no bstr info for some base
        """
        bstr_info = self._get_bstr_info(bases)
        already_done_bases, bases_to_download = self._split_existing_bases(bstr_info)
        logger.info('Will schedule %d bases to download: %s', len(bases_to_download), list(bases_to_download.keys()))

        # Size chunks from the bases that really need downloading (reused ones
        # would inflate the limit), aiming at ~desired_n_subtasks subtasks but
        # never below 75 GiB per chunk. Guard against 0/None subtask counts.
        n_subtasks = max(1, self.Parameters.desired_n_subtasks or 1)
        total_size = sum(info['fsize'] for info in bases_to_download.values())
        limit = max(total_size // n_subtasks, 75 << 30)  # 75 GiB
        chunks = self._split_into_chunks(bases_to_download, limit)

        logger.info('Got %s chunks with limit %s GiB: %s', len(chunks), round(float(limit) / ((2 ** 10) ** 3), 3), chunks)

        download_task_ids = []
        decompression_ratio = 3
        for idx, chunk in enumerate(chunks, 1):
            download_task = YabsServerDownloadBases(
                self,
                description='{} chunk#{}: {}'.format(description, idx, ', '.join(chunk['items'].keys())),
                tags=self.Parameters.tags,
                base_bstr_info=chunk['items'],
                __requirements__={
                    'tasks_resource': self.Requirements.tasks_resource,
                    # 16 << 10 (presumably MiB, i.e. 16 GiB) of headroom plus the
                    # chunk size in MiB times (1 + decompression_ratio) for the
                    # compressed and unpacked copies.
                    'disk_space': (16 << 10) + (chunk['size'] >> 20) * (1 + decompression_ratio),
                },
            ).enqueue()
            download_task_ids.append(download_task.id)

        return download_task_ids, already_done_bases

    def _process_subtask_results(self):
        """Collect results of the previous download round and count failures.

        Merges every subtask's ``base_resources`` into
        ``Context.already_done_bases``. It then increments
        ``Context.bases_fail_counter`` for each base of a failed subtask that
        is still not available.

        :returns: ``(completely_failed_bases, retry_count)``. The first lists
            bases that failed more than ``fails_limit`` times. The second is the
            highest failure count among bases that will be retried (at least 1).
        """
        subtask_statuses = list(check_tasks(self, self.Context.download_task_ids, raise_on_fail=False))
        failed_statuses = {task_type.Status.FAILURE, } | task_type.Status.Group.BREAK

        # Merge results first so a base delivered by any subtask (even a failed
        # one) is never counted as failed.
        done_bases = {}
        for task, _ in subtask_statuses:
            done_bases.update(task.Parameters.base_resources or {})
        self.Context.already_done_bases.update(done_bases)

        retry_count = 1
        bases_fail_counter = deepcopy(self.Context.bases_fail_counter)
        completely_failed_bases = []
        for task, status in subtask_statuses:
            logger.info('Task #%d is in %s status', task.id, status)
            if status in failed_statuses:
                failed_bases = task.Parameters.failed_to_download_bases or task.Parameters.base_bstr_info.keys()
                for base_tag in failed_bases:
                    if base_tag in self.Context.already_done_bases:
                        continue
                    bases_fail_counter[base_tag] = bases_fail_counter.get(base_tag, 0) + 1
                    # Count first, then compare: at most fails_limit failures
                    # are tolerated per base.
                    if bases_fail_counter[base_tag] > self.Parameters.fails_limit:
                        completely_failed_bases.append(base_tag)
                    else:
                        retry_count = max(retry_count, bases_fail_counter[base_tag])
            elif status != task_type.Status.SUCCESS:
                logger.error('Unknown task status %s', status)

        self.Context.bases_fail_counter = bases_fail_counter
        self.Context.save()
        return completely_failed_bases, retry_count

    def on_execute(self):
        requested_bases = self.Parameters.bases or []

        if self.Context.download_task_ids:
            completely_failed_bases, retry_count = self._process_subtask_results()
            if completely_failed_bases:
                raise TaskFailure('Bases {} exceeded their download failures count limit'.format(completely_failed_bases))
            description = "Retry download #{}".format(retry_count)
        else:
            description = "Download"

        bases_to_schedule = sorted(set(requested_bases) - set(self.Context.already_done_bases.keys()))

        if bases_to_schedule:
            self.Context.download_task_ids, already_done_bases = self._launch_download_tasks(bases_to_schedule, description=description)
            self.Context.already_done_bases.update(already_done_bases)
            self.Context.save()
            # Nothing to wait for when every base was found among existing resources.
            if self.Context.download_task_ids:
                check_tasks(self, self.Context.download_task_ids, raise_on_fail=False)

        self.Parameters.base_resources = self.Context.already_done_bases
