# -*- coding: utf-8 -*-

from sandbox import common
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
import sandbox.common.types.task as ctt

from sandbox.projects.yabs.base_bin_task import BaseBinTask

from sandbox.projects.autobudget.back_to_back.lib.stand import YabsOldAutobudgetBackToBackStand
from sandbox.projects.autobudget.back_to_back.lib.table_uploader import YabsAutobudgetTableUploader


class YabsAutobudgetTableUploadCoordinator(BaseBinTask):
    """Fan out uploads of a stand task's dumped tables and wait for them.

    For every resource id in ``stand_task.Parameters.dumped_tables`` one
    ``YabsAutobudgetTableUploader`` child task is created and enqueued. The
    coordinator then waits until all children reach a FINISH or BREAK status
    group and fails if any child did not end in ``SUCCESS``.
    """

    class Parameters(BaseBinTask.Parameters):
        # Stand task whose ``dumped_tables`` parameter lists the dumps to upload.
        stand_task = sdk2.parameters.Task(
            "Stand task",
            task_type=YabsOldAutobudgetBackToBackStand,
            required=True,
        )

        with sdk2.parameters.Group("Infrastructure parameters"):
            resource_attrs = sdk2.parameters.Dict(
                "Filter resource by",
                default={"name": "autobudget-back-to-back-binary"},
                description="Will be passed to 'attrs' search parameter",
            )

    class Requirements(BaseBinTask.Requirements):
        cores = 1
        ram = 8192  # NOTE(review): presumably MiB, per Sandbox convention — confirm

        # Deliberately empty cache declaration; presumably signals that the
        # task needs no shared caches — confirm against Sandbox docs.
        class Caches(sdk2.Requirements.Caches):
            pass

    def on_execute(self):
        """Create, enqueue and await one upload subtask per dumped table.

        The work is split into memoized stages so that, when the task is
        re-executed after ``sdk2.WaitTask``, subtasks are not created or
        enqueued a second time:

        1. ``create_subtasks`` creates one uploader per dump resource id and
           stores their ids in ``self.Context.subtask_ids``.
        2. ``schedule_subtasks`` enqueues them and suspends this task until
           all of them reach a FINISH or BREAK status. This step is skipped
           when there is nothing to upload.
        3. Every subtask's final status is checked.

        Raises:
            TaskFailure: if any subtask did not finish with ``SUCCESS``; the
                message lists each failing subtask id and its status.
        """
        stand_task = self.Parameters.stand_task

        with self.memoize_stage.create_subtasks(commit_on_entrance=False):
            self.Context.subtask_ids = [
                YabsAutobudgetTableUploader(
                    self,
                    description="Upload dumped table to YT",
                    notifications=[],
                    dump_to_upload=dump_resource_id,
                ).id
                for dump_resource_id in stand_task.Parameters.dumped_tables
            ]

        subtask_ids = self.Context.subtask_ids
        with self.memoize_stage.schedule_subtasks(commit_on_entrance=False):
            # Only suspend when there is actually something to wait for;
            # waiting on an empty task list is pointless.
            if subtask_ids:
                for task_id in subtask_ids:
                    sdk2.Task[task_id].enqueue()

                raise sdk2.WaitTask(
                    subtask_ids,
                    common.utils.chain(ctt.Status.Group.FINISH, ctt.Status.Group.BREAK),
                    wait_all=True,
                )

        # Collect every non-successful subtask so the failure is actionable
        # without opening each child task by hand.
        failed = []
        for subtask_id in subtask_ids:
            status = sdk2.Task[subtask_id].status
            if status != ctt.Status.SUCCESS:
                failed.append("{} ({})".format(subtask_id, status))

        if failed:
            raise TaskFailure("Some of uploads failed: {}".format(", ".join(failed)))
