# -*- coding: utf-8 -*-

from sandbox import sdk2
from sandbox import common
import time
import logging
from datetime import datetime
from sandbox.projects import resource_types
from sandbox.sdk2.helpers import subprocess as sp
from sandbox.common.types import task
from sandbox.sandboxsdk import environments


class LbExtractShard(sdk2.Task):
    """
    Extract the data needed by the lingboosting YT stand from a snapshot
    of the Jupiter JudTier base.

    The task runs in one of two modes, selected by its parameters:

    * ``shard_resource`` is set: run the ``extract_lb_shard`` binary on that
      single shard;
    * otherwise: look up the JudTier shards of ``jupiter_state``, create the
      output YT tables, spawn one child ``LbExtractShard`` per shard and wait
      for the children to finish.
    """
    class Parameters(sdk2.Task.Parameters):
        worker_binary = sdk2.parameters.Resource(
            "extract_lb_shard binary",
            resource_type=resource_types.ARCADIA_PROJECT,
            required=True,
            __doc__="Create this with YA_MAKE task for quality/relev_tools/lboost_ops/mr_index/extract_lb_shard")

        jupiter_state = sdk2.parameters.String(
            "Jupiter state to work with",
            __doc__="Set to extract data from all JudTier shards from this state")

        shard_resource = sdk2.parameters.Resource(
            "Search database",
            resource_type=resource_types.SEARCH_DATABASE,
            __doc__="Set to extract data from one shard")

        yt_server = sdk2.parameters.String(
            "YT server",
            default="hahn.yt.yandex.net",
            required=True)

        yt_vault_owner = sdk2.parameters.String(
            "Owner of vault item with YT token",
            default='RATED-POOL-MAKERS',
            required=True,
            __doc__="Owner from https://sandbox.yandex-team.ru/admin/vault")

        yt_vault_name = sdk2.parameters.String(
            "Name of vault item with YT token",
            default='RATED_POOLS_YT_TOKEN',
            required=True,
            __doc__="Name from https://sandbox.yandex-team.ru/admin/vault")

        input_pool = sdk2.parameters.String(
            "Input YT table with features",
            required=True,
            __doc__="Main protopool table")

        with sdk2.parameters.RadioGroup("Match mode", required=True, __doc__="How to match documents in pool and shard") as match_mode:
            match_mode.values.docid = match_mode.Value("By docid")
            match_mode.values.robot_url = match_mode.Value("By robot url")

        max_docs = sdk2.parameters.Integer(
            "Max documents per shard",
            default=2000,
            __doc__="Split shards to subshards of given size")

        output_indices = sdk2.parameters.String(
            "Output YT table with indices",
            required=True,
            __doc__="Core table for lingboosting experiments")

        output_pool = sdk2.parameters.String(
            "Output YT table with features",
            __doc__="Copy of input protopool with replaced Shard+DocId")

    class Requirements(sdk2.Requirements):
        cores = 1

        # yt.wrapper is imported lazily in _validate_output_absent().
        environments = (environments.PipEnvironment('yandex-yt', use_wheel=True),)

        class Caches(sdk2.Requirements.Caches):
            pass

    class Context(sdk2.Task.Context):
        # Ids of the per-shard subtasks; empty until the children are spawned.
        children_ids = []

    def on_execute(self):
        """
        Task entry point.

        Single-shard mode (``shard_resource`` is set) processes the shard and
        returns. Otherwise the first run discovers the shards, creates the
        output tables, spawns the children and raises ``sdk2.WaitTask``; the
        run after the wait only checks the children's statuses.

        Raises:
            common.errors.TaskFailure: a child did not end in SUCCESS, or the
                shards for ``jupiter_state`` could not be determined.
            sdk2.WaitTask: to suspend the task until the children finish.
        """
        if self.Parameters.shard_resource:
            self._process_single_shard()
            return

        if self.Context.children_ids:
            # Children were already spawned on a previous run: only verify them.
            for child_id in self.Context.children_ids:
                child = sdk2.Task[child_id]
                if child.status != task.Status.SUCCESS:
                    raise common.errors.TaskFailure("Child has completed with status {}".format(child.status))
            return

        # Discover the shards BEFORE creating the output tables: a bad
        # jupiter_state or a state without shards must not leave freshly
        # created tables behind, since they would break the next run.
        shards = self._find_shards()
        self._validate_output_absent()
        self.Context.children_ids = self._create_children(shards)
        raise sdk2.WaitTask(self.Context.children_ids, task.Status.Group.FINISH)

    def _find_shards(self):
        """
        Look up the SEARCH_DATABASE resources of the JudTier shards that
        belong to ``jupiter_state``.

        Shard numbers are probed from 0 upwards (at most 1000) and the scan
        stops at the first number for which no resource is found.

        Returns:
            list of ``(shard_name, shard_resource)`` pairs, never empty.

        Raises:
            common.errors.TaskFailure: ``jupiter_state`` is empty or is not in
                the ``%Y%m%d-%H%M%S`` format, or no shard resource was found.
        """
        state = self.Parameters.jupiter_state
        if not state:
            raise common.errors.TaskFailure("Either shard_resource or jupiter_state must be set")
        try:
            parsed_state = datetime.strptime(state, '%Y%m%d-%H%M%S')
        except ValueError as error:
            raise common.errors.TaskFailure("Invalid jupiter_state {!r}: {}".format(state, error))
        # NOTE(review): time.mktime() interprets the value as local time, so the
        # shard names depend on the timezone of the host -- confirm this matches
        # how shard_instance attributes are generated.
        timestamp = int(time.mktime(parsed_state.timetuple()))

        shards = []
        for shard_num in range(1000):
            shard_name = "primus-JudTier-{}-{}-{}".format(0, shard_num, timestamp)
            shard_resource = sdk2.Resource.find(resource_types.SEARCH_DATABASE, attrs={'shard_instance': shard_name}).first()
            if shard_resource is None:
                break
            shards.append((shard_name, shard_resource))
        if not shards:
            raise common.errors.TaskFailure("No JudTier shards found for given state")
        return shards

    def _create_children(self, shards=None):
        """
        Spawn and enqueue one child ``LbExtractShard`` task per shard.

        Args:
            shards: ``(shard_name, shard_resource)`` pairs as returned by
                ``_find_shards()``; looked up here when omitted.

        Returns:
            list of ids of the enqueued subtasks, in shard order.

        Raises:
            common.errors.TaskFailure: propagated from ``_find_shards()`` when
                ``shards`` is omitted and the lookup fails.
        """
        if shards is None:
            shards = self._find_shards()

        # Try to prevent YT error 'Limit for the number of concurrent operations 50 for pool "robot-pool-collector" in tree "physical" has been reached':
        # limit number of concurrently running tasks for the pair (server, yt token),
        # yt token is identified by (yt_vault_owner, yt_vault_name).
        # Just in case, make tasks limit less than YT limit.
        semaphor_name = "{}_{}_{}".format(self.Parameters.yt_vault_owner, self.Parameters.yt_vault_name, self.Parameters.yt_server)

        children = []
        for shard_name, shard_resource in shards:
            subtask = LbExtractShard(
                self,
                description=self.Parameters.description+", shard "+shard_name,
                owner=self.Parameters.owner,
                priority=task.Priority(task.Priority.Class.SERVICE, task.Priority.Subclass.NORMAL),
                worker_binary=self.Parameters.worker_binary,
                shard_resource=shard_resource,
                yt_server=self.Parameters.yt_server,
                yt_vault_owner=self.Parameters.yt_vault_owner,
                yt_vault_name=self.Parameters.yt_vault_name,
                input_pool=self.Parameters.input_pool,
                match_mode=self.Parameters.match_mode,
                max_docs=self.Parameters.max_docs,
                output_indices=self.Parameters.output_indices,
                output_pool=self.Parameters.output_pool)
            subtask.Requirements.semaphores = task.Semaphores(
                acquires=[task.Semaphores.Acquire(name=semaphor_name, weight=1, capacity=40)],
                release=(task.Status.Group.BREAK, task.Status.Group.FINISH, task.Status.Group.WAIT)
            )
            subtask.save()
            # Request 45 GiB of disk for the child through the REST API after saving it.
            sdk2.Task.server.task[subtask.id].update({'requirements': {'disk_space': 45 << 30}})
            subtask.enqueue()
            children.append(subtask.id)
        return children

    def _get_yt_token(self):
        """Return the YT token stored in the vault item (yt_vault_owner, yt_vault_name)."""
        return sdk2.Vault.data(self.Parameters.yt_vault_owner, self.Parameters.yt_vault_name)

    def _validate_output_absent(self):
        """
        Create the output YT tables: ``output_indices`` always, ``output_pool``
        only when it is set.

        NOTE(review): ``create`` is called without ``ignore_existing``, so it
        presumably fails when a table already exists -- which is what the
        method name refers to; confirm against the YT client documentation.
        """
        # Imported here because yandex-yt is provided by the task's PipEnvironment.
        import yt.wrapper as yt
        yt_client = yt.YtClient(proxy=self.Parameters.yt_server, token=self._get_yt_token())
        yt_client.create("table", self.Parameters.output_indices, recursive=True, attributes={"compression_codec": "brotli_3", "erasure_codec": "lrc_12_2_2"})
        if self.Parameters.output_pool:
            yt_client.create("table", self.Parameters.output_pool, recursive=True, attributes={"compression_codec": "brotli_5", "erasure_codec": "lrc_12_2_2"})

    def _process_single_shard(self):
        """
        Run the ``extract_lb_shard`` binary on ``shard_resource``.

        ``--output-table`` and ``--max-docs`` are passed only when the
        corresponding parameters are set. The process output is written to the
        task's process log; a non-zero exit code raises from ``check_call``.
        """
        worker_binary_path = str(sdk2.ResourceData(self.Parameters.worker_binary).path)
        shard_path = str(sdk2.ResourceData(self.Parameters.shard_resource).path)
        command = [
            worker_binary_path,
            # The binary spells the mode with a dash, the task parameter with an underscore.
            "--match-mode", {"docid": "docid", "robot_url": "robot-url"}[self.Parameters.match_mode],
            "--server", self.Parameters.yt_server,
            "--pool-table", self.Parameters.input_pool,
            "--shard", shard_path,
            "--indices-table", self.Parameters.output_indices
        ]
        if self.Parameters.output_pool:
            command += ["--output-table", self.Parameters.output_pool]
        if self.Parameters.max_docs:
            command += ["--max-docs", str(self.Parameters.max_docs)]
        with sdk2.helpers.ProcessLog(self, logger=logging.getLogger("extract_lb_shard")) as pl:
            # The token is passed through the environment rather than argv.
            # NOTE(review): with standard subprocess semantics env= replaces the
            # whole inherited environment -- confirm that is intended.
            sp.check_call(command, env={"YT_TOKEN": self._get_yt_token()}, stdout=pl.stdout, stderr=sp.STDOUT)