# -*- coding: utf-8 -*-
import json
from sandbox import sdk2
from sandbox.projects.modadvert.common.modadvert import ModadvertBaseYtTask
from sandbox.projects.modadvert.RunLaaSMR import ModadvertRunLaaSMR
from sandbox.sandboxsdk.environments import PipEnvironment


class ModadvertRunCupidMR(ModadvertBaseYtTask):
    """
    Runs given binaries with Cupid MR runners.

    Splits content of given `src_tables` into several temporary tables according to objects' types.
    Then runs one MR runner subtask per object type and concatenates the results of the runners
    into `dst_table`.
    """

    name = 'MODADVERT_RUN_CUPID_MR'

    class Requirements(sdk2.Task.Requirements):
        # Pinned pip packages installed into the task environment.
        environments = (
            PipEnvironment('yandex-yt', '0.9.9'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', '0.3.7.post1')
        )

    class Parameters(ModadvertBaseYtTask.Parameters):

        kill_timeout = 30 * 60 * 60  # 30h

        with sdk2.parameters.Group('Resources') as binary_group:
            # Mapping "object type -> resource id"; its keys define the object types to process.
            mr_runners = sdk2.parameters.Dict('Named list of LaaS MR-runner binaries (in format object_type: resource_id)', required=True)

        with sdk2.parameters.Group('Tables') as tables_group:
            src_tables = sdk2.parameters.List('Path to YT tables with requests', required=True)
            dst_table = sdk2.parameters.String('Path to YT table where results will be stored', required=True)
            tmp_directory = sdk2.parameters.String('Path to directory where temporary tables will be stored', default=None)

        with sdk2.parameters.Group('YT run parameters') as yt_group:
            # Re-declared from ModadvertRunLaaSMR; the values are forwarded unchanged to every runner subtask.
            memory_limit = ModadvertRunLaaSMR.Parameters.memory_limit()
            job_count = ModadvertRunLaaSMR.Parameters.job_count()
            user_slots = ModadvertRunLaaSMR.Parameters.user_slots()
            max_failed_job_count = ModadvertRunLaaSMR.Parameters.max_failed_job_count()
            pool = ModadvertRunLaaSMR.Parameters.pool()
            mount_sandbox_in_tmpfs = ModadvertRunLaaSMR.Parameters.mount_sandbox_in_tmpfs()
            lock_attempts = ModadvertRunLaaSMR.Parameters.lock_attempts()

    def _create_runner_subtask(self, object_type):
        """
        Create a ModadvertRunLaaSMR subtask that processes a single object type.

        :param object_type: key of `Parameters.mr_runners`; it must also be present in
            `Context.src_tables_by_type` and `Context.dst_tables_by_type`
        :return: whatever `create_subtask` returns for the new subtask
        :raises KeyError: if `object_type` is missing from one of the mappings above
        """
        runner_params = ModadvertRunLaaSMR.Parameters
        return self.create_subtask(
            ModadvertRunLaaSMR,
            {
                runner_params.yt_proxy_url.name: self.Parameters.yt_proxy_url,
                runner_params.tokens.name: self.Parameters.tokens,
                runner_params.vault_user.name: self.Parameters.vault_user,
                runner_params.binaries_resource.name: self.Parameters.mr_runners[object_type],
                # The runner takes a list of source tables; each type has exactly one temporary table.
                runner_params.src_tables.name: [self.Context.src_tables_by_type[object_type]],
                runner_params.dst_table.name: self.Context.dst_tables_by_type[object_type],
                runner_params.memory_limit.name: self.Parameters.memory_limit,
                runner_params.job_count.name: self.Parameters.job_count,
                runner_params.user_slots.name: self.Parameters.user_slots,
                runner_params.max_failed_job_count.name: self.Parameters.max_failed_job_count,
                runner_params.pool.name: self.Parameters.pool,
                runner_params.mount_sandbox_in_tmpfs.name: self.Parameters.mount_sandbox_in_tmpfs,
                runner_params.lock_attempts.name: self.Parameters.lock_attempts
            },
            'Run MR runner for "{}"'.format(object_type)
        )

    def on_before_execute(self):
        """
        Fix the list of object types and prepare per-type temporary tables.

        Both steps are skipped when the corresponding Context fields are already filled,
        so the preparation is performed at most once per task.
        """
        super(ModadvertRunCupidMR, self).on_before_execute()

        if not self.Context.object_types:
            self.Context.object_types = list(self.Parameters.mr_runners.keys())
        if not self.Context.src_tables_by_type:
            # NOTE(review): imported locally, presumably because yt_runner depends on the
            # pip environments that are only available at execution time - confirm.
            from . import yt_runner
            call_result = json.loads(yt_runner.run(
                task='prepare_tables',
                yt_proxy_url=self.Parameters.yt_proxy_url,
                yt_token=self.get_yt_token(),
                src_tables=self.Parameters.src_tables,
                object_types=self.Context.object_types,
                expiration_timeout=self.Parameters.kill_timeout * 1000,  # from seconds to milliseconds
                tmp_directory=self.Parameters.tmp_directory
            ))
            # Returned table lists are matched to object types by position.
            self.Context.src_tables_by_type = dict(zip(self.Context.object_types, call_result['src_tables']))
            self.Context.dst_tables_by_type = dict(zip(self.Context.object_types, call_result['dst_tables']))

    def on_execute_inner(self):
        """
        Start one runner subtask per object type (unless already started) and wait for all of them.

        The created subtask is remembered in Context under the key `<object_type>_runner_subtask`,
        which prevents creating a duplicate when this method runs again.
        """
        super(ModadvertRunCupidMR, self).on_execute_inner()

        # Iterate the object types themselves: per-type tables are looked up by key inside
        # _create_runner_subtask, so zipping with the table mappings is neither needed nor safe
        # (it would iterate dict keys and silently truncate on a length mismatch).
        for object_type in self.Context.object_types:
            subtask_key = '{}_runner_subtask'.format(object_type)
            if not getattr(self.Context, subtask_key):
                setattr(self.Context, subtask_key, self._create_runner_subtask(object_type))
        self.wait_all_subtasks()

    def on_after_execute(self):
        """
        Concatenate per-type result tables into `Parameters.dst_table`.

        The `Context.results_are_concatenated` flag ensures the concatenation is requested only once.
        """
        super(ModadvertRunCupidMR, self).on_after_execute()

        if not self.Context.results_are_concatenated:
            from . import yt_runner
            yt_runner.run(
                task='concatenate_results',
                yt_proxy_url=self.Parameters.yt_proxy_url,
                yt_token=self.get_yt_token(),
                result_tables=list(self.Context.dst_tables_by_type.values()),
                dst_table=self.Parameters.dst_table
            )
            self.Context.results_are_concatenated = True
