# -*- coding: utf-8 -*-

from sandbox import sdk2
from sandbox import common
import random
from sandbox.common.types import task
from sandbox.projects.common.search.database import iss_shards
from sandbox.projects.judtier import shards
from sandbox.projects.judtier import get_dups_table
from sandbox.projects.judtier import import_shard
from sandbox.sandboxsdk import environments


class JudTierImportShardSet(sdk2.Task):
    """
    Download all JudTier data needed by MAKE_ESTIMATED_POOL for a single
    Jupiter state.

    First run: resolves the requested state, spawns one JudTierGetDupsTable
    child and one JudTierImportShard child per shard, and waits for them.
    Second run (after the children finish): fails/stops if any child did not
    succeed, then optionally tags one randomly chosen shard and its dups table
    with TestEnv auto-update attributes.
    """

    class Parameters(sdk2.Task.Parameters):
        # Either a concrete state id (starts with a digit) or an alias that is
        # looked up in the Jupiter metadata by _resolve_state.
        jupiter_state = sdk2.parameters.String('Jupiter state to work with', required=True, description="Юпитерный стейт или алиас, e.g. 20161231-123456 или production_current_state")
        # When set, _postprocess marks one random imported shard (and its dups
        # table) with TE_prod_* attributes.
        add_testenv_prod_attrs = sdk2.parameters.Bool('Add TestEnv prod attributes', default=False, description="Выбрать случайный выкачанный шард и пометить его атрибутом автоапдейта TestEnv")

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        environments = (environments.PipEnvironment('yandex-yt', use_wheel=True),)

        class Caches(sdk2.Requirements.Caches):
            pass

    class Context(sdk2.Task.Context):
        # Resolved (non-alias) Jupiter state the children were created for.
        state = ''
        # Names of the shards requested from the JudTierImportShard children.
        shards = []
        # Ids of all spawned children; non-empty means we are on the re-run
        # that happens after sdk2.WaitTask.
        children_ids = []

    def on_execute(self):
        """Spawn children on the first run; verify them and postprocess on the re-run.

        Raises:
            common.errors.TaskStop: a child ended in a BREAK-group status.
            common.errors.TaskFailure: a child finished without SUCCESS.
            sdk2.WaitTask: on the first run, to suspend until children finish.
        """
        if self.Context.children_ids:
            for child_id in self.Context.children_ids:
                child = sdk2.Task[child_id]
                if child.status in task.Status.Group.BREAK:
                    raise common.errors.TaskStop("Child has completed with status {}".format(child.status))
                elif child.status != task.Status.SUCCESS:
                    raise common.errors.TaskFailure("Child has completed with status {}".format(child.status))
            self._postprocess()
        else:
            self.Context.children_ids = self._create_children()
            raise sdk2.WaitTask(self.Context.children_ids, task.Status.Group.FINISH | task.Status.Group.BREAK)

    def _resolve_state(self, state, yt_wrapper):
        """Translate a Jupiter state alias into a concrete state id.

        Reads the ``@jupiter_meta`` attribute of ``//home/jupiter`` on the
        ``arnold`` YT proxy and returns the value stored under ``state``.

        Raises:
            common.errors.TaskFailure: the alias is not present in the metadata.
        """
        yt_client = yt_wrapper.YtClient(proxy="arnold", token=yt_wrapper.config["token"])
        metadata = yt_client.get("//home/jupiter/@jupiter_meta")
        if state not in metadata:
            raise common.errors.TaskFailure("Failed to resolve Jupiter state alias " + state)
        self.set_info("Jupiter state {} resolved to {}".format(state, metadata[state]))
        return metadata[state]

    def _create_children(self):
        """Enqueue the dups-table child and one shard-import child per shard.

        Also records the resolved state and the shard names in the context.

        Returns:
            list: ids of all enqueued child tasks (dups-table child first).
        """
        import yt.wrapper as yt
        yt.config["token"] = sdk2.Vault.data('RATED-POOL-MAKERS', 'RATED_POOLS_YT_TOKEN')
        used_state = self.Parameters.jupiter_state
        # Slice instead of index so an empty value is treated as an alias and
        # fails with a clear "Failed to resolve" error instead of IndexError.
        if not used_state[:1].isdigit():
            used_state = self._resolve_state(used_state, yt)
        self.Context.state = used_state
        shards_count = shards.detect_shards_count(self, used_state, yt)
        shard_timestamp = shards.state_to_timestamp(used_state)
        child_priority = task.Priority(task.Priority.Class.SERVICE, task.Priority.Subclass.NORMAL)

        children = []
        dups_subtask = get_dups_table.JudTierGetDupsTable(
            self,
            description=self.Parameters.description + ", get dups table",
            owner=self.Parameters.owner,
            priority=child_priority,
            jupiter_state=used_state,
            split_shards=True,
            precreate_count=shards_count).enqueue()
        children.append(dups_subtask.id)

        shards_list = []
        for shard_num in range(shards_count):
            shard_name = "primus-JudTier-{}-{}-{}".format(0, shard_num, shard_timestamp)
            shard_rbtorrent = iss_shards.get_shard_name(shard_name)
            shard_subtask = import_shard.JudTierImportShard(
                self,
                description=self.Parameters.description + ", get shard " + shard_name,
                owner=self.Parameters.owner,
                priority=child_priority,
                search_database_rsync_path=shard_rbtorrent,
                database_shard_name=shard_name,
            )
            shard_subtask.enqueue()
            children.append(shard_subtask.id)
            shards_list.append(shard_name)

        self.Context.shards = shards_list
        return children

    @staticmethod
    def _find_ready_resource(resource_type, **attrs):
        """Return the first READY resource of ``resource_type`` matching ``attrs``.

        Raises:
            common.errors.TaskFailure: no such resource exists.
        """
        resource = sdk2.Resource[resource_type].find(state='READY', attrs=attrs).first()
        if resource is None:
            raise common.errors.TaskFailure(
                "No READY {} resource found with attributes {}".format(resource_type, attrs))
        return resource

    def _postprocess(self):
        """Optionally tag one random shard and its dups table for TestEnv.

        Both resources are looked up before either is modified, so a missing
        resource does not leave only one of them tagged.

        Raises:
            common.errors.TaskFailure: no shards were requested, or a matching
                READY resource could not be found.
        """
        if not self.Parameters.add_testenv_prod_attrs:
            return
        if not self.Context.shards:
            raise common.errors.TaskFailure("No shards were imported, cannot set TestEnv attributes")
        selected_shard = random.choice(self.Context.shards)
        shard_resource = self._find_ready_resource("SEARCH_DATABASE", shard_instance=selected_shard)
        dups_resource = self._find_ready_resource("JUDTIER_DUPS_TABLE", shard_name=selected_shard)
        shard_resource.TE_prod_JudTier = self.Context.state
        dups_resource.TE_prod_JudTier_dups = self.Context.state
